[LLM] Yuan2 MLP and Rotary optimization (#10231)

* Add optimization for rotary embedding

* Add fused MLP optimization

* Python style fix

* Fix rotary embedding due to logits difference

* Small fix
Author: Yuwen Hu, 2024-02-26 15:10:08 +08:00 (committed by GitHub)
Parent: 5ad752bae8
Commit: e38e29511c
3 changed files with 70 additions and 8 deletions


@@ -1157,8 +1157,13 @@ def _optimize_post(model, lightweight_bmm=False):
         modeling_module_name = model.__class__.__module__
         module = importlib.import_module(modeling_module_name)
         from bigdl.llm.transformers.models.yuan import yuan_attention_forward
+        from bigdl.llm.transformers.models.yuan import yuan_mlp_forward
         convert_forward(model,
                         module.YuanAttention,
                         yuan_attention_forward
                         )
+        convert_forward(model,
+                        module.YuanMLP,
+                        yuan_mlp_forward
+                        )
     return model
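
For reference, convert_forward is the existing BigDL helper that swaps the forward implementation of every submodule of a given class; the hunk above registers yuan_mlp_forward for YuanMLP alongside the existing YuanAttention replacement. The snippet below is only a minimal sketch of that kind of patching, not the actual bigdl.llm implementation; the helper name and the method-rebinding approach are assumptions for illustration.

import types

import torch.nn as nn


def convert_forward_sketch(model: nn.Module, target_class: type, new_forward) -> nn.Module:
    # Walk the module tree and rebind `forward` on every instance of the
    # target class, so later calls dispatch to the optimized function.
    for module in model.modules():
        if module.__class__ == target_class:
            module.forward = types.MethodType(new_forward, module)
    return model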

View file

@@ -182,7 +182,7 @@ def apply_rotary_pos_emb_no_cache_xpu(q, k, position_ids, model_family):
     q_embed = torch.empty(q.shape, dtype=q.dtype, device=q.device)
     k_embed = torch.empty(k.shape, dtype=k.dtype, device=k.device)
     if model_family in ["llama", "baichuan", "internlm", "aquila", "gpt_neox", "mistral",
-                        "mixtral", "qwen2"]:
+                        "mixtral"]:
         linear_q4_0.apply_rotary_embedding_half_q_and_k(q, k, position_ids, q_embed, k_embed)
         return q_embed, k_embed
     else:
@@ -199,7 +199,7 @@ def apply_rotary_pos_emb_cache_freq_xpu(q, k, sin, cos, model_family, position_ids
     k_embed = torch.empty(k.shape, dtype=k.dtype, device=k.device)
     if model_family in ["qwen", "mixtral"]:
         linear_q4_0.apply_rotary_embedding_half_q_and_k_cache_freq(q, k, sin, cos, q_embed, k_embed)
-    elif model_family in ["qwen2"]:
+    elif model_family in ["qwen2", "yuan"]:
         cos = cos.to(q.dtype)
         sin = sin.to(q.dtype)
         cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
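
The cache_freq path above hands contiguous [seq_len, dim] cos/sin tables, cast to the query dtype, to the fused XPU kernel. As a reference point for the "logits difference" fix in the commit message, the unfused computation the kernel is expected to match is the standard rotate_half formulation used by the upstream modeling code; the sketch below assumes that convention and is for illustration only.

import torch


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Swap the two halves of the last dimension, negating the second half.
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2:]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_reference(q, k, cos, sin, position_ids):
    # cos/sin: [seq_len, head_dim]; pick the rows for the current positions
    # and broadcast over the head dimension.
    cos = cos[position_ids].unsqueeze(1)  # [bsz, 1, seq_len, head_dim]
    sin = sin[position_ids].unsqueeze(1)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed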


@@ -27,8 +27,10 @@ from typing import Optional, Tuple
 import torch
 import torch.nn as nn
 from bigdl.llm.utils.common import invalidInputError
-from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb
+from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb, \
+    apply_rotary_pos_emb_cache_freq_xpu, mlp_fusion_check, fp16_fusion_check
 from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
 from bigdl.llm.transformers.models.utils import is_enough_kv_cache_room_4_31
 from bigdl.llm.transformers.low_bit_linear import SYM_INT4, FP8E5
@@ -48,6 +50,52 @@ def should_use_fuse_rope(self, hidden_states, position_ids):
     return use_fuse_rope


+def yuan_mlp_forward(
+    self,
+    x: torch.Tensor,
+    residual=None
+) -> torch.Tensor:
+    x_2d = x.view(-1, x.shape[-1])
+    bsz, hidden_size = x_2d.shape
+    qtype = getattr(self.up_proj, "qtype", None)
+    if mlp_fusion_check(x_2d, qtype, self.training):
+        import linear_q4_0
+        if not x_2d.is_contiguous():
+            x_2d = x_2d.contiguous()
+        out = self.down_proj(linear_q4_0.mlp_forward_xpu(
+            x_2d, self.up_proj.weight.data, self.gate_proj.weight.data,
+            x_2d.shape[0], x_2d.shape[1], self.up_proj.out_len,
+            qtype
+        ))
+        if residual is not None:
+            return out + residual
+        else:
+            return out
+    elif fp16_fusion_check(self.up_proj, x, self.training) and \
+            hidden_size == 4096 and bsz == 1:
+        hidden_states1 = torch.ops.torch_ipex.mm_silu(x, self.up_proj.weight)
+        hidden_states = torch.ops.torch_ipex.mm_resmul(
+            x, self.gate_proj.weight, hidden_states1
+        )
+        if residual is None:
+            hidden_states = torch.matmul(hidden_states, self.down_proj.weight)
+        else:
+            attn_output = torch.addmm(
+                residual.flatten(0, -2),
+                hidden_states.flatten(0, -2),
+                self.down_proj.weight,
+                beta=1,
+            )
+            hidden_states = attn_output.view(x.shape)
+        return hidden_states
+    else:
+        out = self.down_proj(self.act_fn(self.up_proj(x)) * self.gate_proj(x))
+        if residual is not None:
+            return out + residual
+        else:
+            return out
+
+
 def yuan_attention_forward(
     self,
     hidden_states: torch.Tensor,
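
The new yuan_mlp_forward above picks one of three paths: the low-bit fused kernel (linear_q4_0.mlp_forward_xpu), an fp16 IPEX fused path (mm_silu followed by mm_resmul, only for batch size 1 and hidden size 4096), and the plain eager fallback. All three compute the same SwiGLU-style MLP, where Yuan applies the activation to up_proj rather than gate_proj. The reference module below restates the fallback branch only; the SiLU activation and the layer sizes are illustrative assumptions, not taken from this commit.

import torch
import torch.nn as nn


class YuanMLPReference(nn.Module):
    # Unfused equivalent of the fallback branch in yuan_mlp_forward.
    # hidden_size/intermediate_size are placeholders for illustration.
    def __init__(self, hidden_size: int = 4096, intermediate_size: int = 8192):
        super().__init__()
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
        self.act_fn = nn.SiLU()  # assumed activation; Yuan applies it to up_proj

    def forward(self, x: torch.Tensor, residual=None) -> torch.Tensor:
        out = self.down_proj(self.act_fn(self.up_proj(x)) * self.gate_proj(x))
        return out + residual if residual is not None else out
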
@@ -57,6 +105,7 @@ def yuan_attention_forward(
     output_attentions: bool = False,
     use_cache: bool = False,
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+    use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
     bsz, q_len, _ = hidden_states.size()
     device = hidden_states.device
     before_hidden_states = None
@@ -112,12 +161,20 @@ def yuan_attention_forward(
     kv_seq_len = key_states.shape[-2]
     if past_key_value is not None:
         kv_seq_len += past_key_value[0].shape[-2]
     cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
-    query_states, key_states = apply_rotary_pos_emb(query_states,
-                                                    key_states,
-                                                    cos, sin,
-                                                    position_ids,
-                                                    "yuan")
+    if use_fuse_rope:
+        query_states, key_states = apply_rotary_pos_emb_cache_freq_xpu(query_states,
+                                                                       key_states,
+                                                                       sin, cos,
+                                                                       "yuan",
+                                                                       position_ids)
+    else:
+        query_states, key_states = apply_rotary_pos_emb(query_states,
+                                                        key_states,
+                                                        cos, sin,
+                                                        position_ids,
+                                                        "yuan")

     if past_key_value is not None:
         # reuse k, v, self_attention
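
With these changes, yuan_attention_forward dispatches to the fused apply_rotary_pos_emb_cache_freq_xpu kernel whenever should_use_fuse_rope allows it, and falls back to the original apply_rotary_pos_emb path otherwise. The sketch below shows how the two optimizations would be exercised end to end, assuming the standard bigdl-llm AutoModelForCausalLM loading API; the checkpoint id and device handling are illustrative assumptions, not part of this commit.

import intel_extension_for_pytorch as ipex  # imported for its "xpu" device support
from bigdl.llm.transformers import AutoModelForCausalLM

# Placeholder Yuan2 checkpoint id; a Yuan2 model with its custom modeling code
# triggers _optimize_post's YuanAttention/YuanMLP conversion shown above.
model = AutoModelForCausalLM.from_pretrained(
    "IEITYuan/Yuan2-2B-hf",
    load_in_4bit=True,        # low-bit weights enable the fused MLP kernel path
    trust_remote_code=True,
)
model = model.to("xpu")       # fused rotary/MLP kernels target Intel XPU devices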