optimize minicpm3 kv cache (#12052)
This commit is contained in:
parent
5d3ab16a80
commit
d8c044e79d
2 changed files with 112 additions and 22 deletions
|
|
@ -998,6 +998,8 @@ def _optimize_pre(model, qtype=None):
|
|||
if model.config.model_type == "minicpm3":
|
||||
from ipex_llm.transformers.models.minicpm3 import pre_compute_inv_freq
|
||||
model.apply(pre_compute_inv_freq)
|
||||
from ipex_llm.transformers.models.minicpm3 import padding_v_head_dim
|
||||
model.apply(padding_v_head_dim)
|
||||
if model.config.model_type == "minicpmv":
|
||||
from ipex_llm.transformers.models.minicpmv import merge_qkv
|
||||
model.vpm.apply(merge_qkv)
|
||||
|
|
@ -1780,7 +1782,7 @@ def _optimize_post(model, lightweight_bmm=False):
|
|||
elif model.config.model_type == "gemma2":
|
||||
modeling_module_name = model.__class__.__module__
|
||||
module = importlib.import_module(modeling_module_name)
|
||||
from ipex_llm.transformers.models.common import mlp_silu_forward
|
||||
from ipex_llm.transformers.models.common import mlp_gelu_forward
|
||||
from ipex_llm.transformers.models.gemma import gemma_rms_norm_forward
|
||||
from ipex_llm.transformers.models.gemma2 import gemma2_attention_forward
|
||||
from ipex_llm.transformers.models.gemma2 import gemma2_model_forward
|
||||
|
|
@ -1789,7 +1791,7 @@ def _optimize_post(model, lightweight_bmm=False):
|
|||
convert_forward(model, Gemma2RMSNorm, gemma_rms_norm_forward)
|
||||
convert_forward(model, Gemma2Attention, gemma2_attention_forward)
|
||||
convert_forward(model, Gemma2Model, gemma2_model_forward)
|
||||
convert_forward(model, Gemma2MLP, mlp_silu_forward)
|
||||
convert_forward(model, Gemma2MLP, mlp_gelu_forward)
|
||||
elif model.config.model_type == "Yi":
|
||||
modeling_module_name = model.__class__.__module__
|
||||
module = importlib.import_module(modeling_module_name)
|
||||
|
|
@ -1974,10 +1976,13 @@ def _optimize_post(model, lightweight_bmm=False):
|
|||
module = importlib.import_module(modeling_module_name)
|
||||
from ipex_llm.transformers.models.common import rms_norm_forward
|
||||
from ipex_llm.transformers.models.common import mlp_silu_forward
|
||||
from ipex_llm.transformers.models.minicpm3 import minicpm3_attention_forward
|
||||
from ipex_llm.transformers.models.minicpm3 import minicpm3_model_forward_wrapper
|
||||
convert_forward(model, module.MiniCPMRMSNorm, rms_norm_forward)
|
||||
convert_forward(model, module.MiniCPMMLP, mlp_silu_forward)
|
||||
from ipex_llm.transformers.models.minicpm3 import minicpm3_attention_forward
|
||||
convert_forward(model, module.MiniCPMAttention, minicpm3_attention_forward)
|
||||
minicpm3_model_forward = minicpm3_model_forward_wrapper(module.MiniCPM3Model.forward)
|
||||
convert_forward(model, module.MiniCPM3Model, minicpm3_model_forward)
|
||||
elif model.config.model_type == "minicpmv":
|
||||
modeling_module_name = model.__class__.__module__
|
||||
module = importlib.import_module(modeling_module_name)
|
||||
|
|
|
|||
|
|
@ -2,12 +2,15 @@ import torch
|
|||
import warnings
|
||||
|
||||
from torch import nn
|
||||
from typing import Optional, Tuple
|
||||
from typing import Optional, Tuple, List
|
||||
from transformers.cache_utils import Cache
|
||||
|
||||
from ipex_llm.utils.common.log4Error import invalidInputError
|
||||
from ipex_llm.transformers.models.utils import should_use_fuse_rope
|
||||
from ipex_llm.transformers.models.utils import rotate_half
|
||||
from ipex_llm.transformers.models.utils import use_sdp, use_sdp_causal
|
||||
from ipex_llm.transformers.models.utils import use_quantize_kv_cache, restore_fp8_kv_cache
|
||||
from ipex_llm.transformers.kv import DynamicNormalCache, DynamicFp8Cache
|
||||
|
||||
|
||||
def pre_compute_inv_freq(module: torch.nn.Module):
|
||||
|
|
@ -20,6 +23,72 @@ def pre_compute_inv_freq(module: torch.nn.Module):
|
|||
module.register_buffer("short_inv_freq", short_inv_freq, persistent=False)
|
||||
|
||||
|
||||
def padding_v_head_dim(module: torch.nn.Module):
    """Zero-pad the value rows of a MiniCPMAttention ``kv_b_proj`` weight.

    Intended for use with ``model.apply``: modules whose class name is not
    ``MiniCPMAttention`` are left untouched. For a matching module, the
    per-head value slice of ``kv_b_proj`` is widened from ``v_head_dim`` to
    the key head dim (``qk_rope_head_dim + qk_nope_head_dim``) by appending
    zero rows, and ``module.kv_b_proj`` is replaced by a new bias-free
    ``torch.nn.Linear`` holding the padded, non-trainable weight.

    NOTE(review): presumably this lets keys and values share one head dim
    for the KV cache / SDP kernels -- confirm against the attention forward.

    Raises (through ``invalidInputError``) when the key head dim is smaller
    than ``v_head_dim``.
    """
    if module.__class__.__name__ != "MiniCPMAttention":
        return

    nope_dim = module.qk_nope_head_dim
    v_dim = module.v_head_dim
    k_dim = module.qk_rope_head_dim + nope_dim
    invalidInputError(k_dim >= v_dim,
                      f"unsupported k_head_dim and v_head_dim: {k_dim} {v_dim}")

    # Nothing to pad when the value head is already as wide as the key head.
    if v_dim >= k_dim:
        return

    heads = module.num_heads
    rank = module.kv_lora_rank

    # View the flat projection weight per head, then separate the key (nope)
    # rows from the value rows.
    per_head_w = module.kv_b_proj.weight.data.view(heads, nope_dim + v_dim, rank)
    k_part, v_part = per_head_w.split([nope_dim, v_dim], dim=1)

    # Copy the value rows into a zero buffer of the key head width.
    padded_v = v_part.new_zeros([heads, k_dim, rank])
    padded_v[:, :v_dim, :] = v_part
    merged_w = torch.cat([k_part, padded_v], dim=1).view(-1, rank)

    # Build an empty Linear and attach the padded weight afterwards, so no
    # full-size weight is allocated just to be thrown away.
    padded_proj = torch.nn.Linear(0, 0, bias=False,
                                  dtype=merged_w.dtype, device=merged_w.device)
    out_features, in_features = merged_w.shape
    padded_proj.in_features = in_features
    padded_proj.out_features = out_features
    padded_proj.weight = torch.nn.Parameter(merged_w, requires_grad=False)

    module.kv_b_proj = padded_proj
|
||||
|
||||
|
||||
def minicpm3_model_forward_wrapper(origin_forward):
    """Wrap a MiniCPM3 model ``forward`` to prepare the KV cache first.

    The returned function has the same signature as the wrapped forward. It
    resolves ``use_cache`` (falling back to ``self.config.use_cache`` and
    forcing ``True`` when the inputs are on an ``xpu`` device), converts
    ``past_key_values`` to the IPEX-LLM cache class when caching is enabled,
    and then delegates to ``origin_forward`` with keyword arguments.
    """
    def minicpm3_model_forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        # IPEX-LLM OPT: kv cache and quantize kv cache and sdp
        inputs = inputs_embeds if input_ids is None else input_ids

        if use_cache is None:
            use_cache = self.config.use_cache
        # The cache is always enabled on xpu, whatever the caller asked for.
        if inputs.device.type == "xpu":
            use_cache = True

        use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.down_proj, inputs)

        if use_cache:
            # FP8 cache when KV quantization is on, the normal cache otherwise;
            # convert only if the incoming cache is not already that class.
            cache_cls = DynamicFp8Cache if use_quantize_kv else DynamicNormalCache
            if not isinstance(past_key_values, cache_cls):
                past_key_values = cache_cls.from_legacy_cache(past_key_values)

        return origin_forward(
            self=self,
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

    return minicpm3_model_forward
|
||||
|
||||
|
||||
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
|
||||
orig_dtype = k.dtype
|
||||
cos = cos[position_ids].unsqueeze(unsqueeze_dim) # [bs, 1, seq_len, dim]
|
||||
|
|
@ -62,12 +131,12 @@ def minicpm3_attention_forward(
|
|||
k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2)
|
||||
kv = (
|
||||
self.kv_b_proj(self.kv_a_layernorm(compressed_kv))
|
||||
.view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
|
||||
.view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.q_head_dim)
|
||||
.transpose(1, 2)
|
||||
)
|
||||
|
||||
k_nope, value_states = torch.split(
|
||||
kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1
|
||||
kv, [self.qk_nope_head_dim, self.q_head_dim], dim=-1
|
||||
)
|
||||
kv_seq_len = value_states.shape[-2]
|
||||
if past_key_value is not None:
|
||||
|
|
@ -110,25 +179,41 @@ def minicpm3_attention_forward(
|
|||
key_states[:, :, :, self.qk_nope_head_dim:] = k_pe
|
||||
|
||||
if past_key_value is not None:
|
||||
key_states, value_states = past_key_value.update(
|
||||
key_states, value_states, self.layer_idx, None
|
||||
)
|
||||
key_states, value_states = past_key_value.update(key_states, value_states,
|
||||
self.layer_idx, None)
|
||||
|
||||
attn_weights = (
|
||||
torch.matmul(query_states, key_states.transpose(2, 3)) * self.softmax_scale
|
||||
)
|
||||
attn_weights = None
|
||||
if use_sdp(q_len, kv_seq_len, self.q_head_dim, query_states):
|
||||
import xe_addons
|
||||
if isinstance(past_key_value, DynamicFp8Cache):
|
||||
attn_output = xe_addons.sdp_fp8(query_states, key_states, value_states,
|
||||
attention_mask)
|
||||
else:
|
||||
attn_output = xe_addons.sdp(query_states, key_states, value_states,
|
||||
attention_mask)
|
||||
attn_output = attn_output[:, :, :, :self.v_head_dim]
|
||||
elif use_sdp_causal(q_len, kv_seq_len, self.q_head_dim, query_states, False):
|
||||
import xe_addons
|
||||
if isinstance(past_key_value, DynamicFp8Cache):
|
||||
attn_output = xe_addons.sdp_fp8_causal(query_states, key_states,
|
||||
value_states, attention_mask)
|
||||
else:
|
||||
attn_output = xe_addons.sdp_causal(query_states, key_states,
|
||||
value_states, attention_mask)
|
||||
attn_output = attn_output[:, :, :, :self.v_head_dim]
|
||||
else:
|
||||
if isinstance(past_key_value, DynamicFp8Cache):
|
||||
key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
|
||||
query_states.dtype)
|
||||
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.softmax_scale
|
||||
|
||||
if attention_mask is not None:
|
||||
attn_weights = attn_weights + attention_mask
|
||||
if attention_mask is not None:
|
||||
attn_weights = attn_weights + attention_mask
|
||||
|
||||
# upcast attention to fp32
|
||||
attn_weights = nn.functional.softmax(
|
||||
attn_weights, dim=-1, dtype=torch.float32
|
||||
).to(query_states.dtype)
|
||||
attn_weights = nn.functional.dropout(
|
||||
attn_weights, p=self.attention_dropout, training=self.training
|
||||
)
|
||||
attn_output = torch.matmul(attn_weights, value_states)
|
||||
# upcast attention to fp32
|
||||
attn_weights = nn.functional.softmax(attn_weights,
|
||||
dim=-1, dtype=torch.float32).to(query_states.dtype)
|
||||
attn_output = torch.matmul(attn_weights, value_states[:, :, :, :self.v_head_dim])
|
||||
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue