optimize minicpm3 kv cache (#12052)

This commit is contained in:
Yishuo Wang 2024-09-10 16:51:21 +08:00 committed by GitHub
parent 5d3ab16a80
commit d8c044e79d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 112 additions and 22 deletions

View file

@ -998,6 +998,8 @@ def _optimize_pre(model, qtype=None):
if model.config.model_type == "minicpm3":
from ipex_llm.transformers.models.minicpm3 import pre_compute_inv_freq
model.apply(pre_compute_inv_freq)
from ipex_llm.transformers.models.minicpm3 import padding_v_head_dim
model.apply(padding_v_head_dim)
if model.config.model_type == "minicpmv":
from ipex_llm.transformers.models.minicpmv import merge_qkv
model.vpm.apply(merge_qkv)
@ -1780,7 +1782,7 @@ def _optimize_post(model, lightweight_bmm=False):
elif model.config.model_type == "gemma2":
modeling_module_name = model.__class__.__module__
module = importlib.import_module(modeling_module_name)
from ipex_llm.transformers.models.common import mlp_silu_forward
from ipex_llm.transformers.models.common import mlp_gelu_forward
from ipex_llm.transformers.models.gemma import gemma_rms_norm_forward
from ipex_llm.transformers.models.gemma2 import gemma2_attention_forward
from ipex_llm.transformers.models.gemma2 import gemma2_model_forward
@ -1789,7 +1791,7 @@ def _optimize_post(model, lightweight_bmm=False):
convert_forward(model, Gemma2RMSNorm, gemma_rms_norm_forward)
convert_forward(model, Gemma2Attention, gemma2_attention_forward)
convert_forward(model, Gemma2Model, gemma2_model_forward)
convert_forward(model, Gemma2MLP, mlp_silu_forward)
convert_forward(model, Gemma2MLP, mlp_gelu_forward)
elif model.config.model_type == "Yi":
modeling_module_name = model.__class__.__module__
module = importlib.import_module(modeling_module_name)
@ -1974,10 +1976,13 @@ def _optimize_post(model, lightweight_bmm=False):
module = importlib.import_module(modeling_module_name)
from ipex_llm.transformers.models.common import rms_norm_forward
from ipex_llm.transformers.models.common import mlp_silu_forward
from ipex_llm.transformers.models.minicpm3 import minicpm3_attention_forward
from ipex_llm.transformers.models.minicpm3 import minicpm3_model_forward_wrapper
convert_forward(model, module.MiniCPMRMSNorm, rms_norm_forward)
convert_forward(model, module.MiniCPMMLP, mlp_silu_forward)
from ipex_llm.transformers.models.minicpm3 import minicpm3_attention_forward
convert_forward(model, module.MiniCPMAttention, minicpm3_attention_forward)
minicpm3_model_forward = minicpm3_model_forward_wrapper(module.MiniCPM3Model.forward)
convert_forward(model, module.MiniCPM3Model, minicpm3_model_forward)
elif model.config.model_type == "minicpmv":
modeling_module_name = model.__class__.__module__
module = importlib.import_module(modeling_module_name)

View file

@ -2,12 +2,15 @@ import torch
import warnings
from torch import nn
from typing import Optional, Tuple
from typing import Optional, Tuple, List
from transformers.cache_utils import Cache
from ipex_llm.utils.common.log4Error import invalidInputError
from ipex_llm.transformers.models.utils import should_use_fuse_rope
from ipex_llm.transformers.models.utils import rotate_half
from ipex_llm.transformers.models.utils import use_sdp, use_sdp_causal
from ipex_llm.transformers.models.utils import use_quantize_kv_cache, restore_fp8_kv_cache
from ipex_llm.transformers.kv import DynamicNormalCache, DynamicFp8Cache
def pre_compute_inv_freq(module: torch.nn.Module):
@ -20,6 +23,72 @@ def pre_compute_inv_freq(module: torch.nn.Module):
module.register_buffer("short_inv_freq", short_inv_freq, persistent=False)
def padding_v_head_dim(module: torch.nn.Module):
if module.__class__.__name__ == "MiniCPMAttention":
k_head_dim = module.qk_rope_head_dim + module.qk_nope_head_dim
v_head_dim = module.v_head_dim
invalidInputError(k_head_dim >= v_head_dim,
f"unsupported k_head_dim and v_head_dim: {k_head_dim} {v_head_dim}")
if v_head_dim < k_head_dim:
kv_b_proj = module.kv_b_proj
w = kv_b_proj.weight.data.view(module.num_heads,
module.qk_nope_head_dim + module.v_head_dim,
module.kv_lora_rank)
k_w, v_w = w.split([module.qk_nope_head_dim, module.v_head_dim], dim=1)
new_v_w = torch.zeros([module.num_heads, k_head_dim, module.kv_lora_rank],
dtype=v_w.dtype, device=v_w.device)
new_v_w[:, :v_head_dim, :] = v_w
new_w = torch.cat([k_w, new_v_w], dim=1).view(-1, module.kv_lora_rank)
new_kv_b_proj = torch.nn.Linear(0, 0, bias=False,
dtype=new_w.dtype, device=new_w.device)
new_kv_b_proj.in_features = new_w.size(1)
new_kv_b_proj.out_features = new_w.size(0)
new_kv_b_proj.weight = torch.nn.Parameter(new_w, False)
module.kv_b_proj = new_kv_b_proj
def minicpm3_model_forward_wrapper(origin_forward):
def minicpm3_model_forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
):
# IPEX-LLM OPT: kv cache and quantize kv cache and sdp
inputs = input_ids if input_ids is not None else inputs_embeds
use_cache = use_cache if use_cache is not None else self.config.use_cache
use_cache = True if inputs.device.type == "xpu" else use_cache
use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.down_proj, inputs)
if use_cache:
if use_quantize_kv and not isinstance(past_key_values, DynamicFp8Cache):
past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
elif not use_quantize_kv and not isinstance(past_key_values, DynamicNormalCache):
past_key_values = DynamicNormalCache.from_legacy_cache(past_key_values)
return origin_forward(
self=self,
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
return minicpm3_model_forward
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
orig_dtype = k.dtype
cos = cos[position_ids].unsqueeze(unsqueeze_dim) # [bs, 1, seq_len, dim]
@ -62,12 +131,12 @@ def minicpm3_attention_forward(
k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2)
kv = (
self.kv_b_proj(self.kv_a_layernorm(compressed_kv))
.view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
.view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.q_head_dim)
.transpose(1, 2)
)
k_nope, value_states = torch.split(
kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1
kv, [self.qk_nope_head_dim, self.q_head_dim], dim=-1
)
kv_seq_len = value_states.shape[-2]
if past_key_value is not None:
@ -110,25 +179,41 @@ def minicpm3_attention_forward(
key_states[:, :, :, self.qk_nope_head_dim:] = k_pe
if past_key_value is not None:
key_states, value_states = past_key_value.update(
key_states, value_states, self.layer_idx, None
)
key_states, value_states = past_key_value.update(key_states, value_states,
self.layer_idx, None)
attn_weights = (
torch.matmul(query_states, key_states.transpose(2, 3)) * self.softmax_scale
)
attn_weights = None
if use_sdp(q_len, kv_seq_len, self.q_head_dim, query_states):
import xe_addons
if isinstance(past_key_value, DynamicFp8Cache):
attn_output = xe_addons.sdp_fp8(query_states, key_states, value_states,
attention_mask)
else:
attn_output = xe_addons.sdp(query_states, key_states, value_states,
attention_mask)
attn_output = attn_output[:, :, :, :self.v_head_dim]
elif use_sdp_causal(q_len, kv_seq_len, self.q_head_dim, query_states, False):
import xe_addons
if isinstance(past_key_value, DynamicFp8Cache):
attn_output = xe_addons.sdp_fp8_causal(query_states, key_states,
value_states, attention_mask)
else:
attn_output = xe_addons.sdp_causal(query_states, key_states,
value_states, attention_mask)
attn_output = attn_output[:, :, :, :self.v_head_dim]
else:
if isinstance(past_key_value, DynamicFp8Cache):
key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
query_states.dtype)
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.softmax_scale
if attention_mask is not None:
attn_weights = attn_weights + attention_mask
if attention_mask is not None:
attn_weights = attn_weights + attention_mask
# upcast attention to fp32
attn_weights = nn.functional.softmax(
attn_weights, dim=-1, dtype=torch.float32
).to(query_states.dtype)
attn_weights = nn.functional.dropout(
attn_weights, p=self.attention_dropout, training=self.training
)
attn_output = torch.matmul(attn_weights, value_states)
# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights,
dim=-1, dtype=torch.float32).to(query_states.dtype)
attn_output = torch.matmul(attn_weights, value_states[:, :, :, :self.v_head_dim])
attn_output = attn_output.transpose(1, 2).contiguous()