From e38e29511ca05ebd7a0f51a279a320e30ce67981 Mon Sep 17 00:00:00 2001
From: Yuwen Hu <54161268+Oscilloscope98@users.noreply.github.com>
Date: Mon, 26 Feb 2024 15:10:08 +0800
Subject: [PATCH] [LLM] Yuan2 MLP and Rotary optimization (#10231)

* Add optimization for rotary embedding

* Add mlp fused optimization

* Python style fix

* Fix rotary embedding due to logits difference

* Small fix
---
 .../llm/src/bigdl/llm/transformers/convert.py |  5 ++
 .../bigdl/llm/transformers/models/utils.py    |  4 +-
 .../src/bigdl/llm/transformers/models/yuan.py | 69 +++++++++++++++++--
 3 files changed, 70 insertions(+), 8 deletions(-)

diff --git a/python/llm/src/bigdl/llm/transformers/convert.py b/python/llm/src/bigdl/llm/transformers/convert.py
index cd7c60db..3ec171ed 100644
--- a/python/llm/src/bigdl/llm/transformers/convert.py
+++ b/python/llm/src/bigdl/llm/transformers/convert.py
@@ -1157,8 +1157,13 @@ def _optimize_post(model, lightweight_bmm=False):
         modeling_module_name = model.__class__.__module__
         module = importlib.import_module(modeling_module_name)
         from bigdl.llm.transformers.models.yuan import yuan_attention_forward
+        from bigdl.llm.transformers.models.yuan import yuan_mlp_forward
         convert_forward(model,
                         module.YuanAttention,
                         yuan_attention_forward
                         )
+        convert_forward(model,
+                        module.YuanMLP,
+                        yuan_mlp_forward
+                        )
     return model

diff --git a/python/llm/src/bigdl/llm/transformers/models/utils.py b/python/llm/src/bigdl/llm/transformers/models/utils.py
index c238c595..916e855b 100644
--- a/python/llm/src/bigdl/llm/transformers/models/utils.py
+++ b/python/llm/src/bigdl/llm/transformers/models/utils.py
@@ -182,7 +182,7 @@ def apply_rotary_pos_emb_no_cache_xpu(q, k, position_ids, model_family):
     q_embed = torch.empty(q.shape, dtype=q.dtype, device=q.device)
     k_embed = torch.empty(k.shape, dtype=k.dtype, device=k.device)
     if model_family in ["llama", "baichuan", "internlm", "aquila", "gpt_neox", "mistral",
-                        "mixtral", "qwen2"]:
+                        "mixtral"]:
         linear_q4_0.apply_rotary_embedding_half_q_and_k(q, k, position_ids, q_embed, k_embed)
         return q_embed, k_embed
     else:
@@ -199,7 +199,7 @@ def apply_rotary_pos_emb_cache_freq_xpu(q, k, sin, cos, model_family, position_i
     k_embed = torch.empty(k.shape, dtype=k.dtype, device=k.device)
     if model_family in ["qwen", "mixtral"]:
         linear_q4_0.apply_rotary_embedding_half_q_and_k_cache_freq(q, k, sin, cos, q_embed, k_embed)
-    elif model_family in ["qwen2"]:
+    elif model_family in ["qwen2", "yuan"]:
         cos = cos.to(q.dtype)
         sin = sin.to(q.dtype)
         cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
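
Note on the utils.py change above: the "yuan" family is routed through the same
apply_rotary_pos_emb_cache_freq_xpu branch as "qwen2", so Yuan2 reuses the fused
half-rotary XPU kernel with precomputed cos/sin, while "qwen2" is dropped from the
position-ids-only path. For reference, a minimal unfused sketch of the rotary math
involved (illustrative helper names, not part of the patch):

    import torch

    def rotate_half(x):
        # Split the last dimension in half and rotate: (x1, x2) -> (-x2, x1).
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)

    def apply_rotary_reference(q, k, cos, sin):
        # q, k: [batch, heads, seq, head_dim]; cos/sin broadcastable to q and k.
        # The fused kernel produces the same q_embed/k_embed in one device pass.
        q_embed = (q * cos) + (rotate_half(q) * sin)
        k_embed = (k * cos) + (rotate_half(k) * sin)
        return q_embed, k_embed
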
diff --git a/python/llm/src/bigdl/llm/transformers/models/yuan.py b/python/llm/src/bigdl/llm/transformers/models/yuan.py
index ed869a0c..2419fa91 100644
--- a/python/llm/src/bigdl/llm/transformers/models/yuan.py
+++ b/python/llm/src/bigdl/llm/transformers/models/yuan.py
@@ -27,8 +27,10 @@ from typing import Optional, Tuple
 import torch
 import torch.nn as nn
+
 from bigdl.llm.utils.common import invalidInputError
-from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb
+from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb, \
+    apply_rotary_pos_emb_cache_freq_xpu, mlp_fusion_check, fp16_fusion_check
 from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
 from bigdl.llm.transformers.models.utils import is_enough_kv_cache_room_4_31
 from bigdl.llm.transformers.low_bit_linear import SYM_INT4, FP8E5
@@ -48,6 +50,52 @@ def should_use_fuse_rope(self, hidden_states, position_ids):
     return use_fuse_rope


+def yuan_mlp_forward(
+    self,
+    x: torch.Tensor,
+    residual=None
+) -> torch.Tensor:
+    x_2d = x.view(-1, x.shape[-1])
+    bsz, hidden_size = x_2d.shape
+    qtype = getattr(self.up_proj, "qtype", None)
+    if mlp_fusion_check(x_2d, qtype, self.training):
+        import linear_q4_0
+        if not x_2d.is_contiguous():
+            x_2d = x_2d.contiguous()
+        out = self.down_proj(linear_q4_0.mlp_forward_xpu(
+            x_2d, self.up_proj.weight.data, self.gate_proj.weight.data,
+            x_2d.shape[0], x_2d.shape[1], self.up_proj.out_len,
+            qtype
+        ))
+        if residual is not None:
+            return out + residual
+        else:
+            return out
+    elif fp16_fusion_check(self.up_proj, x, self.training) and \
+            hidden_size == 4096 and bsz == 1:
+        hidden_states1 = torch.ops.torch_ipex.mm_silu(x, self.up_proj.weight)
+        hidden_states = torch.ops.torch_ipex.mm_resmul(
+            x, self.gate_proj.weight, hidden_states1
+        )
+        if residual is None:
+            hidden_states = torch.matmul(hidden_states, self.down_proj.weight)
+        else:
+            attn_output = torch.addmm(
+                residual.flatten(0, -2),
+                hidden_states.flatten(0, -2),
+                self.down_proj.weight,
+                beta=1,
+            )
+            hidden_states = attn_output.view(x.shape)
+        return hidden_states
+    else:
+        out = self.down_proj(self.act_fn(self.up_proj(x)) * self.gate_proj(x))
+        if residual is not None:
+            return out + residual
+        else:
+            return out
+
+
 def yuan_attention_forward(
     self,
     hidden_states: torch.Tensor,
@@ -57,6 +105,7 @@ def yuan_attention_forward(
     output_attentions: bool = False,
     use_cache: bool = False,
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+    use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
     bsz, q_len, _ = hidden_states.size()
     device = hidden_states.device
     before_hidden_states = None
@@ -112,12 +161,20 @@ def yuan_attention_forward(
     kv_seq_len = key_states.shape[-2]
     if past_key_value is not None:
         kv_seq_len += past_key_value[0].shape[-2]
+
     cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
-    query_states, key_states = apply_rotary_pos_emb(query_states,
-                                                    key_states,
-                                                    cos, sin,
-                                                    position_ids,
-                                                    "yuan")
+    if use_fuse_rope:
+        query_states, key_states = apply_rotary_pos_emb_cache_freq_xpu(query_states,
+                                                                       key_states,
+                                                                       sin, cos,
+                                                                       "yuan",
+                                                                       position_ids)
+    else:
+        query_states, key_states = apply_rotary_pos_emb(query_states,
+                                                        key_states,
+                                                        cos, sin,
+                                                        position_ids,
+                                                        "yuan")

     if past_key_value is not None:
         # reuse k, v, self_attention
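
Note: with both convert_forward calls in place, the optimized attention and MLP
forwards are applied when a Yuan2 checkpoint is loaded through bigdl-llm with
low-bit optimization enabled. A rough usage sketch (checkpoint id, prompt, and
generation settings are placeholders; assumes an Intel GPU ("xpu") device and the
bigdl.llm.transformers loading API):

    import torch
    from transformers import AutoTokenizer
    from bigdl.llm.transformers import AutoModelForCausalLM

    model_path = "IEITYuan/Yuan2-2B-hf"  # placeholder checkpoint

    # load_in_4bit=True converts linear layers to low-bit modules; the post-load
    # optimization then swaps in yuan_attention_forward and yuan_mlp_forward.
    model = AutoModelForCausalLM.from_pretrained(model_path,
                                                 load_in_4bit=True,
                                                 trust_remote_code=True).to("xpu")
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

    with torch.inference_mode():
        input_ids = tokenizer("What is AI?", return_tensors="pt").input_ids.to("xpu")
        output = model.generate(input_ids, max_new_tokens=32)
        print(tokenizer.decode(output[0], skip_special_tokens=True))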