From d8c044e79d326e445bf25ebf73f4b7c68db57f25 Mon Sep 17 00:00:00 2001
From: Yishuo Wang
Date: Tue, 10 Sep 2024 16:51:21 +0800
Subject: [PATCH] optimize minicpm3 kv cache (#12052)

---
 .../llm/src/ipex_llm/transformers/convert.py  |  11 +-
 .../ipex_llm/transformers/models/minicpm3.py  | 123 +++++++++++++++---
 2 files changed, 112 insertions(+), 22 deletions(-)

diff --git a/python/llm/src/ipex_llm/transformers/convert.py b/python/llm/src/ipex_llm/transformers/convert.py
index 159d3171..f8e82455 100644
--- a/python/llm/src/ipex_llm/transformers/convert.py
+++ b/python/llm/src/ipex_llm/transformers/convert.py
@@ -998,6 +998,8 @@ def _optimize_pre(model, qtype=None):
     if model.config.model_type == "minicpm3":
         from ipex_llm.transformers.models.minicpm3 import pre_compute_inv_freq
         model.apply(pre_compute_inv_freq)
+        from ipex_llm.transformers.models.minicpm3 import padding_v_head_dim
+        model.apply(padding_v_head_dim)
     if model.config.model_type == "minicpmv":
         from ipex_llm.transformers.models.minicpmv import merge_qkv
         model.vpm.apply(merge_qkv)
@@ -1780,7 +1782,7 @@ def _optimize_post(model, lightweight_bmm=False):
     elif model.config.model_type == "gemma2":
         modeling_module_name = model.__class__.__module__
         module = importlib.import_module(modeling_module_name)
-        from ipex_llm.transformers.models.common import mlp_silu_forward
+        from ipex_llm.transformers.models.common import mlp_gelu_forward
         from ipex_llm.transformers.models.gemma import gemma_rms_norm_forward
         from ipex_llm.transformers.models.gemma2 import gemma2_attention_forward
         from ipex_llm.transformers.models.gemma2 import gemma2_model_forward
@@ -1789,7 +1791,7 @@
         convert_forward(model, Gemma2RMSNorm, gemma_rms_norm_forward)
         convert_forward(model, Gemma2Attention, gemma2_attention_forward)
         convert_forward(model, Gemma2Model, gemma2_model_forward)
-        convert_forward(model, Gemma2MLP, mlp_silu_forward)
+        convert_forward(model, Gemma2MLP, mlp_gelu_forward)
     elif model.config.model_type == "Yi":
         modeling_module_name = model.__class__.__module__
         module = importlib.import_module(modeling_module_name)
@@ -1974,10 +1976,13 @@
         module = importlib.import_module(modeling_module_name)
         from ipex_llm.transformers.models.common import rms_norm_forward
         from ipex_llm.transformers.models.common import mlp_silu_forward
+        from ipex_llm.transformers.models.minicpm3 import minicpm3_attention_forward
+        from ipex_llm.transformers.models.minicpm3 import minicpm3_model_forward_wrapper
         convert_forward(model, module.MiniCPMRMSNorm, rms_norm_forward)
         convert_forward(model, module.MiniCPMMLP, mlp_silu_forward)
-        from ipex_llm.transformers.models.minicpm3 import minicpm3_attention_forward
         convert_forward(model, module.MiniCPMAttention, minicpm3_attention_forward)
+        minicpm3_model_forward = minicpm3_model_forward_wrapper(module.MiniCPM3Model.forward)
+        convert_forward(model, module.MiniCPM3Model, minicpm3_model_forward)
     elif model.config.model_type == "minicpmv":
         modeling_module_name = model.__class__.__module__
         module = importlib.import_module(modeling_module_name)
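Note: the convert.py side follows ipex-llm's usual two-phase pattern: _optimize_pre rewrites weights before quantization (here, padding the value head dimension via padding_v_head_dim), and _optimize_post swaps module forwards with convert_forward, wrapping the original MiniCPM3Model.forward so the right KV-cache container is chosen at runtime. A minimal sketch of that wrapper pattern, with hypothetical ToyModel/toy_forward names standing in for the real classes:

    import torch

    class ToyModel(torch.nn.Module):
        def forward(self, x, use_cache=None):
            return x, use_cache

    def toy_forward_wrapper(origin_forward):
        # Same shape as minicpm3_model_forward_wrapper below: capture the
        # original (unbound) forward once, normalize arguments, delegate.
        def toy_forward(self, x, use_cache=None):
            use_cache = True if use_cache is None else use_cache
            return origin_forward(self, x, use_cache=use_cache)
        return toy_forward

    ToyModel.forward = toy_forward_wrapper(ToyModel.forward)
    print(ToyModel()(torch.ones(1)))  # (tensor([1.]), True)

Because the wrapper closes over the original function rather than copying its body, it keeps the upstream signature intact and delegates everything it does not touch.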
diff --git a/python/llm/src/ipex_llm/transformers/models/minicpm3.py b/python/llm/src/ipex_llm/transformers/models/minicpm3.py
index a47b7647..820cce22 100644
--- a/python/llm/src/ipex_llm/transformers/models/minicpm3.py
+++ b/python/llm/src/ipex_llm/transformers/models/minicpm3.py
@@ -2,12 +2,15 @@
 import torch
 import warnings
 from torch import nn
-from typing import Optional, Tuple
+from typing import Optional, Tuple, List

 from transformers.cache_utils import Cache

 from ipex_llm.utils.common.log4Error import invalidInputError
 from ipex_llm.transformers.models.utils import should_use_fuse_rope
 from ipex_llm.transformers.models.utils import rotate_half
+from ipex_llm.transformers.models.utils import use_sdp, use_sdp_causal
+from ipex_llm.transformers.models.utils import use_quantize_kv_cache, restore_fp8_kv_cache
+from ipex_llm.transformers.kv import DynamicNormalCache, DynamicFp8Cache

 def pre_compute_inv_freq(module: torch.nn.Module):
@@ -20,6 +23,72 @@
     module.register_buffer("short_inv_freq", short_inv_freq, persistent=False)


+def padding_v_head_dim(module: torch.nn.Module):
+    if module.__class__.__name__ == "MiniCPMAttention":
+        k_head_dim = module.qk_rope_head_dim + module.qk_nope_head_dim
+        v_head_dim = module.v_head_dim
+        invalidInputError(k_head_dim >= v_head_dim,
+                          f"unsupported k_head_dim and v_head_dim: {k_head_dim} {v_head_dim}")
+        if v_head_dim < k_head_dim:
+            kv_b_proj = module.kv_b_proj
+            w = kv_b_proj.weight.data.view(module.num_heads,
+                                           module.qk_nope_head_dim + module.v_head_dim,
+                                           module.kv_lora_rank)
+            k_w, v_w = w.split([module.qk_nope_head_dim, module.v_head_dim], dim=1)
+            new_v_w = torch.zeros([module.num_heads, k_head_dim, module.kv_lora_rank],
+                                  dtype=v_w.dtype, device=v_w.device)
+            new_v_w[:, :v_head_dim, :] = v_w
+            new_w = torch.cat([k_w, new_v_w], dim=1).view(-1, module.kv_lora_rank)
+
+            new_kv_b_proj = torch.nn.Linear(0, 0, bias=False,
+                                            dtype=new_w.dtype, device=new_w.device)
+            new_kv_b_proj.in_features = new_w.size(1)
+            new_kv_b_proj.out_features = new_w.size(0)
+            new_kv_b_proj.weight = torch.nn.Parameter(new_w, False)
+
+            module.kv_b_proj = new_kv_b_proj
+
+
+def minicpm3_model_forward_wrapper(origin_forward):
+    def minicpm3_model_forward(
+        self,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ):
+        # IPEX-LLM OPT: kv cache and quantize kv cache and sdp
+        inputs = input_ids if input_ids is not None else inputs_embeds
+        use_cache = use_cache if use_cache is not None else self.config.use_cache
+        use_cache = True if inputs.device.type == "xpu" else use_cache
+        use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.down_proj, inputs)
+        if use_cache:
+            if use_quantize_kv and not isinstance(past_key_values, DynamicFp8Cache):
+                past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
+            elif not use_quantize_kv and not isinstance(past_key_values, DynamicNormalCache):
+                past_key_values = DynamicNormalCache.from_legacy_cache(past_key_values)
+
+        return origin_forward(
+            self=self,
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+    return minicpm3_model_forward
+
+
 def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
     orig_dtype = k.dtype
     cos = cos[position_ids].unsqueeze(unsqueeze_dim)  # [bs, 1, seq_len, dim]
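Note: the zero-padding in padding_v_head_dim is value-preserving. Zero rows appended to the value half of the fused kv_b_proj weight produce zero trailing columns in the projected values, and those columns are sliced off again after attention. A self-contained check of that equivalence, using toy sizes rather than MiniCPM3's real dimensions:

    import torch

    num_heads, qk_nope, qk_rope, v_dim, rank = 2, 4, 2, 3, 8
    k_head_dim = qk_nope + qk_rope                 # 6; value heads padded 3 -> 6

    w = torch.randn(num_heads * (qk_nope + v_dim), rank)   # fused kv_b_proj weight
    k_w, v_w = w.view(num_heads, qk_nope + v_dim, rank).split([qk_nope, v_dim], dim=1)

    new_v_w = torch.zeros(num_heads, k_head_dim, rank)
    new_v_w[:, :v_dim, :] = v_w                    # zero rows = zero output columns
    new_w = torch.cat([k_w, new_v_w], dim=1).view(-1, rank)

    x = torch.randn(rank)
    old_v = v_w @ x                                # per-head values, unpadded
    new_v = (new_w @ x).view(num_heads, qk_nope + k_head_dim)[:, qk_nope:][:, :v_dim]
    assert torch.equal(old_v, new_v)               # padding is value-preserving

Matching the key and value head dimensions lets keys and values share one cache layout and one fused SDP kernel, at the cost of a little extra compute on the zero columns.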
@@ -62,12 +131,12 @@ def minicpm3_attention_forward(
     k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2)
     kv = (
         self.kv_b_proj(self.kv_a_layernorm(compressed_kv))
-        .view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
+        .view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.q_head_dim)
         .transpose(1, 2)
     )

     k_nope, value_states = torch.split(
-        kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1
+        kv, [self.qk_nope_head_dim, self.q_head_dim], dim=-1
     )
     kv_seq_len = value_states.shape[-2]
     if past_key_value is not None:
@@ -110,25 +179,41 @@
     key_states[:, :, :, self.qk_nope_head_dim:] = k_pe

     if past_key_value is not None:
-        key_states, value_states = past_key_value.update(
-            key_states, value_states, self.layer_idx, None
-        )
+        key_states, value_states = past_key_value.update(key_states, value_states,
+                                                         self.layer_idx, None)

-    attn_weights = (
-        torch.matmul(query_states, key_states.transpose(2, 3)) * self.softmax_scale
-    )
+    attn_weights = None
+    if use_sdp(q_len, kv_seq_len, self.q_head_dim, query_states):
+        import xe_addons
+        if isinstance(past_key_value, DynamicFp8Cache):
+            attn_output = xe_addons.sdp_fp8(query_states, key_states, value_states,
+                                            attention_mask)
+        else:
+            attn_output = xe_addons.sdp(query_states, key_states, value_states,
+                                        attention_mask)
+        attn_output = attn_output[:, :, :, :self.v_head_dim]
+    elif use_sdp_causal(q_len, kv_seq_len, self.q_head_dim, query_states, False):
+        import xe_addons
+        if isinstance(past_key_value, DynamicFp8Cache):
+            attn_output = xe_addons.sdp_fp8_causal(query_states, key_states,
+                                                   value_states, attention_mask)
+        else:
+            attn_output = xe_addons.sdp_causal(query_states, key_states,
+                                               value_states, attention_mask)
+        attn_output = attn_output[:, :, :, :self.v_head_dim]
+    else:
+        if isinstance(past_key_value, DynamicFp8Cache):
+            key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
+                                                            query_states.dtype)
+        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.softmax_scale

-    if attention_mask is not None:
-        attn_weights = attn_weights + attention_mask
+        if attention_mask is not None:
+            attn_weights = attn_weights + attention_mask

-    # upcast attention to fp32
-    attn_weights = nn.functional.softmax(
-        attn_weights, dim=-1, dtype=torch.float32
-    ).to(query_states.dtype)
-    attn_weights = nn.functional.dropout(
-        attn_weights, p=self.attention_dropout, training=self.training
-    )
-    attn_output = torch.matmul(attn_weights, value_states)
+        # upcast attention to fp32
+        attn_weights = nn.functional.softmax(attn_weights,
+                                             dim=-1, dtype=torch.float32).to(query_states.dtype)
+        attn_output = torch.matmul(attn_weights, value_states[:, :, :, :self.v_head_dim])

     attn_output = attn_output.transpose(1, 2).contiguous()
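Note: every branch of the new attention path slices its output back to self.v_head_dim, undoing the padding after the fused xe_addons kernels (or the eager fallback) have run on the full q_head_dim-wide value states. A plain-PyTorch check of that pad-then-slice invariant (illustrative shapes only; no xe_addons, and the softmax scale does not affect the argument):

    import torch

    bsz, heads, q_len, kv_len, k_dim, v_dim = 1, 2, 4, 4, 6, 3
    q = torch.randn(bsz, heads, q_len, k_dim)
    k = torch.randn(bsz, heads, kv_len, k_dim)
    v = torch.randn(bsz, heads, kv_len, v_dim)
    v_pad = torch.nn.functional.pad(v, (0, k_dim - v_dim))  # pad last dim to k_dim

    probs = torch.softmax(q @ k.transpose(2, 3) / k_dim ** 0.5, dim=-1)
    out_ref = probs @ v                                     # unpadded attention
    out_pad = (probs @ v_pad)[..., :v_dim]                  # padded, then sliced
    assert torch.allclose(out_ref, out_pad)

The same reasoning should cover the FP8 branches: restore_fp8_kv_cache dequantizes the cached keys and values back to the query dtype, and zeros round-trip to zeros, so the trailing value columns stay zero and the slice remains exact.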