[LLM] Yuan2 MLP and Rotary optimization (#10231)
* Add optimization for rotary embedding
* Add MLP fused optimization
* Python style fix
* Fix rotary embedding due to logits difference
* Small fix
parent 5ad752bae8
commit e38e29511c
3 changed files with 70 additions and 8 deletions
@@ -1157,8 +1157,13 @@ def _optimize_post(model, lightweight_bmm=False):
         modeling_module_name = model.__class__.__module__
         module = importlib.import_module(modeling_module_name)
         from bigdl.llm.transformers.models.yuan import yuan_attention_forward
+        from bigdl.llm.transformers.models.yuan import yuan_mlp_forward
         convert_forward(model,
                         module.YuanAttention,
                         yuan_attention_forward
                         )
+        convert_forward(model,
+                        module.YuanMLP,
+                        yuan_mlp_forward
+                        )
         return model
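For context, convert_forward is the helper that makes these optimized forwards take effect: it walks the loaded model and rebinds forward on every submodule of the target class. A minimal sketch of that idea (names suffixed with _sketch are illustrative, not the library's actual implementation):

import torch.nn as nn

def convert_forward_sketch(model: nn.Module, target_cls, new_forward):
    # Recursively rebind `forward` on every submodule that is an instance of
    # `target_cls`; `new_forward` is bound so that `self` refers to the module.
    for child in model.children():
        if isinstance(child, target_cls):
            child.forward = new_forward.__get__(child, child.__class__)
        convert_forward_sketch(child, target_cls, new_forward)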
@@ -182,7 +182,7 @@ def apply_rotary_pos_emb_no_cache_xpu(q, k, position_ids, model_family):
     q_embed = torch.empty(q.shape, dtype=q.dtype, device=q.device)
     k_embed = torch.empty(k.shape, dtype=k.dtype, device=k.device)
     if model_family in ["llama", "baichuan", "internlm", "aquila", "gpt_neox", "mistral",
-                        "mixtral", "qwen2"]:
+                        "mixtral"]:
         linear_q4_0.apply_rotary_embedding_half_q_and_k(q, k, position_ids, q_embed, k_embed)
         return q_embed, k_embed
     else:
@@ -199,7 +199,7 @@ def apply_rotary_pos_emb_cache_freq_xpu(q, k, sin, cos, model_family, position_ids
     k_embed = torch.empty(k.shape, dtype=k.dtype, device=k.device)
     if model_family in ["qwen", "mixtral"]:
         linear_q4_0.apply_rotary_embedding_half_q_and_k_cache_freq(q, k, sin, cos, q_embed, k_embed)
-    elif model_family in ["qwen2"]:
+    elif model_family in ["qwen2", "yuan"]:
         cos = cos.to(q.dtype)
         sin = sin.to(q.dtype)
         cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
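The fused linear_q4_0 kernels above apply rotate-half rotary embedding to q and k in a single XPU pass. As a point of reference for what they are expected to compute, this is the standard eager-mode formulation (a sketch assuming [bsz, heads, seq, head_dim] tensors and [max_seq_len, head_dim] cos/sin tables; function names are illustrative):

import torch

def rotate_half(x):
    # split the head dimension in two halves and rotate: (x1, x2) -> (-x2, x1)
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2:]
    return torch.cat((-x2, x1), dim=-1)

def rope_reference(q, k, cos, sin, position_ids):
    # q, k: [bsz, num_heads, seq_len, head_dim]
    # cos, sin: [max_seq_len, head_dim] lookup tables indexed by position_ids
    cos = cos[position_ids].unsqueeze(1)  # -> [bsz, 1, seq_len, head_dim]
    sin = sin[position_ids].unsqueeze(1)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed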
@@ -27,8 +27,10 @@ from typing import Optional, Tuple
 import torch
 import torch.nn as nn

 from bigdl.llm.utils.common import invalidInputError
-from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb
+from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb, \
+    apply_rotary_pos_emb_cache_freq_xpu, mlp_fusion_check, fp16_fusion_check
 from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
 from bigdl.llm.transformers.models.utils import is_enough_kv_cache_room_4_31
 from bigdl.llm.transformers.low_bit_linear import SYM_INT4, FP8E5
@@ -48,6 +50,52 @@ def should_use_fuse_rope(self, hidden_states, position_ids):
     return use_fuse_rope


+def yuan_mlp_forward(
+    self,
+    x: torch.Tensor,
+    residual=None
+) -> torch.Tensor:
+    x_2d = x.view(-1, x.shape[-1])
+    bsz, hidden_size = x_2d.shape
+    qtype = getattr(self.up_proj, "qtype", None)
+    if mlp_fusion_check(x_2d, qtype, self.training):
+        import linear_q4_0
+        if not x_2d.is_contiguous():
+            x_2d = x_2d.contiguous()
+        out = self.down_proj(linear_q4_0.mlp_forward_xpu(
+            x_2d, self.up_proj.weight.data, self.gate_proj.weight.data,
+            x_2d.shape[0], x_2d.shape[1], self.up_proj.out_len,
+            qtype
+        ))
+        if residual is not None:
+            return out + residual
+        else:
+            return out
+    elif fp16_fusion_check(self.up_proj, x, self.training) and \
+            hidden_size == 4096 and bsz == 1:
+        hidden_states1 = torch.ops.torch_ipex.mm_silu(x, self.up_proj.weight)
+        hidden_states = torch.ops.torch_ipex.mm_resmul(
+            x, self.gate_proj.weight, hidden_states1
+        )
+        if residual is None:
+            hidden_states = torch.matmul(hidden_states, self.down_proj.weight)
+        else:
+            attn_output = torch.addmm(
+                residual.flatten(0, -2),
+                hidden_states.flatten(0, -2),
+                self.down_proj.weight,
+                beta=1,
+            )
+            hidden_states = attn_output.view(x.shape)
+        return hidden_states
+    else:
+        out = self.down_proj(self.act_fn(self.up_proj(x)) * self.gate_proj(x))
+        if residual is not None:
+            return out + residual
+        else:
+            return out
+
+
 def yuan_attention_forward(
     self,
     hidden_states: torch.Tensor,
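The first branch of yuan_mlp_forward is gated by mlp_fusion_check. The exact conditions live in bigdl.llm.transformers.models.utils and may differ, but the kind of gate it expresses looks roughly like the sketch below (the constant values and the _sketch name are placeholders for illustration only):

import torch

# hypothetical qtype codes for illustration; the real constants come from
# bigdl.llm.transformers.low_bit_linear (SYM_INT4, FP8E5)
SYM_INT4, FP8E5 = 2, 19

def mlp_fusion_check_sketch(x_2d: torch.Tensor, qtype, training: bool) -> bool:
    # the fused low-bit kernel is only worthwhile for single-row (decode-step)
    # inputs on an XPU device, with a supported quantization type and no autograd
    if x_2d.dim() != 2 or x_2d.shape[0] != 1:
        return False
    if x_2d.device.type != "xpu":
        return False
    if qtype not in (SYM_INT4, FP8E5):
        return False
    if training or x_2d.requires_grad:
        return False
    return True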
@@ -57,6 +105,7 @@ def yuan_attention_forward(
     output_attentions: bool = False,
     use_cache: bool = False,
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+    use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
     bsz, q_len, _ = hidden_states.size()
     device = hidden_states.device
     before_hidden_states = None
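should_use_fuse_rope (already defined above the new MLP code) decides per call whether the fused XPU rotary path can be taken. A minimal sketch of the usual gate, assuming it follows the pattern bigdl-llm uses for other models (the _sketch name is illustrative):

import torch

def should_use_fuse_rope_sketch(self, hidden_states: torch.Tensor, position_ids) -> bool:
    # the fused rotary kernel only applies on XPU, outside training/autograd,
    # and when explicit position ids are available to pass to the kernel
    on_xpu = hidden_states.device.type == "xpu"
    no_grad = not (self.training and hidden_states.requires_grad)
    return on_xpu and no_grad and position_ids is not None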
@@ -112,12 +161,20 @@ def yuan_attention_forward(
     kv_seq_len = key_states.shape[-2]
     if past_key_value is not None:
         kv_seq_len += past_key_value[0].shape[-2]

     cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
-    query_states, key_states = apply_rotary_pos_emb(query_states,
-                                                    key_states,
-                                                    cos, sin,
-                                                    position_ids,
-                                                    "yuan")
+    if use_fuse_rope:
+        query_states, key_states = apply_rotary_pos_emb_cache_freq_xpu(query_states,
+                                                                       key_states,
+                                                                       sin, cos,
+                                                                       "yuan",
+                                                                       position_ids)
+    else:
+        query_states, key_states = apply_rotary_pos_emb(query_states,
+                                                        key_states,
+                                                        cos, sin,
+                                                        position_ids,
+                                                        "yuan")

     if past_key_value is not None:
         # reuse k, v, self_attention
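End to end, these forwards are picked up automatically when a Yuan2 checkpoint is loaded through bigdl-llm: loading runs _optimize_post, which rebinds yuan_attention_forward and yuan_mlp_forward onto the model. A usage sketch (the checkpoint id and the XPU device are assumptions, not part of this commit):

import torch
import intel_extension_for_pytorch as ipex  # registers the "xpu" device
from transformers import AutoTokenizer
from bigdl.llm.transformers import AutoModelForCausalLM

# load_in_4bit=True quantizes the Linear layers (SYM_INT4), which is one of the
# qtypes the fused MLP path above accepts
model = AutoModelForCausalLM.from_pretrained("IEITYuan/Yuan2-2B-hf",  # assumed checkpoint id
                                             load_in_4bit=True,
                                             trust_remote_code=True)
model = model.to("xpu")  # the fused kernels in this commit target Intel XPU devices
tokenizer = AutoTokenizer.from_pretrained("IEITYuan/Yuan2-2B-hf", trust_remote_code=True)

with torch.inference_mode():
    input_ids = tokenizer("What is AI?", return_tensors="pt").input_ids.to("xpu")
    output = model.generate(input_ids, max_new_tokens=32)
    print(tokenizer.decode(output[0], skip_special_tokens=True))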