diff --git a/python/llm/src/bigdl/llm/transformers/convert.py b/python/llm/src/bigdl/llm/transformers/convert.py
index 2df8811a..6d1c2d5e 100644
--- a/python/llm/src/bigdl/llm/transformers/convert.py
+++ b/python/llm/src/bigdl/llm/transformers/convert.py
@@ -485,6 +485,7 @@ def _optimize_post(model, lightweight_bmm=False):
             modeling_module_name = model.__class__.__module__
             module = importlib.import_module(modeling_module_name)
             from bigdl.llm.transformers.models.baichuan2 import baichuan_attention_forward_7b
+            from bigdl.llm.transformers.models.baichuan2 import baichuan_mlp_forward
             convert_forward(model,
                             module.Attention,
                             baichuan_attention_forward_7b
@@ -492,12 +493,16 @@ def _optimize_post(model, lightweight_bmm=False):
             convert_forward(model,
                             module.RMSNorm,
                             llama_rms_norm_forward)
+            convert_forward(model,
+                            module.MLP,
+                            baichuan_mlp_forward)
         elif model.config.hidden_size == 5120:
             # baichuan2-13B
             modeling_module_name = model.__class__.__module__
             module = importlib.import_module(modeling_module_name)
             from bigdl.llm.transformers.models.baichuan2 import baichuan_attention_forward_13b
             from bigdl.llm.transformers.models.baichuan2 import baichuan_13b_rms_norm_forward
+            from bigdl.llm.transformers.models.baichuan2 import baichuan_mlp_forward
             convert_forward(model,
                             module.BaichuanAttention,
                             baichuan_attention_forward_13b
@@ -506,6 +511,9 @@ def _optimize_post(model, lightweight_bmm=False):
             convert_forward(model,
                             module.RMSNorm,
                             baichuan_13b_rms_norm_forward)
+            convert_forward(model,
+                            module.MLP,
+                            baichuan_mlp_forward)
     elif model.config.model_type == "baichuan":
         # baichuan1
         if model.config.hidden_size == 4096:
diff --git a/python/llm/src/bigdl/llm/transformers/models/baichuan2.py b/python/llm/src/bigdl/llm/transformers/models/baichuan2.py
index bf1add5f..39c4a2f0 100644
--- a/python/llm/src/bigdl/llm/transformers/models/baichuan2.py
+++ b/python/llm/src/bigdl/llm/transformers/models/baichuan2.py
@@ -70,6 +70,23 @@ def baichuan_13b_rms_norm_forward(self, hidden_states):
     return self.weight * hidden_states.to(input_dtype)
 
 
+def baichuan_mlp_forward(
+    self,
+    x: torch.Tensor,
+) -> torch.Tensor:
+    if x.shape[1] == 1 and x.dtype == torch.float32 and x.device.type == 'xpu' \
+            and not (self.training and x.requires_grad):
+        import linear_q4_0
+        x_2d = x.view(-1, x.shape[-1])
+        if not x_2d.is_contiguous():
+            x_2d = x_2d.contiguous()
+        return self.down_proj(linear_q4_0.mlp_forward_q4_0_xpu(
+            x_2d, self.gate_proj.weight.data, self.up_proj.weight.data,
+            x_2d.shape[0], x_2d.shape[1], self.gate_proj.out_len,
+        ))
+    return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+
+
 def baichuan_attention_forward_7b(
     self,
     hidden_states: torch.Tensor,
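For context, this diff relies on `convert_forward` (defined elsewhere in `convert.py`, not shown here) to walk the model and rebind the `forward` method of every submodule matching the target class, so each `module.MLP` instance starts routing through `baichuan_mlp_forward`. The sketch below is illustrative only: `patch_forward`, `ToyMLP`, and `logged_mlp_forward` are hypothetical names standing in for the library's actual helpers, and it shows just the eager SwiGLU fallback path (the fused `linear_q4_0.mlp_forward_q4_0_xpu` branch needs the XPU extension and is omitted).

```python
# Minimal sketch of the forward-rebinding pattern used by convert_forward,
# under the assumptions stated above. Not the library's actual implementation.
import torch
import torch.nn as nn


def patch_forward(model: nn.Module, target_cls: type, new_forward) -> None:
    """Rebind `forward` on every submodule that is an instance of target_cls."""
    for submodule in model.modules():
        if isinstance(submodule, target_cls):
            # Bind the plain function as a method of this specific instance,
            # shadowing the class-level forward.
            submodule.forward = new_forward.__get__(submodule, submodule.__class__)


class ToyMLP(nn.Module):
    """Stand-in for the Baichuan2 MLP: gate/up/down projections plus SiLU."""
    def __init__(self, hidden: int = 8, intermediate: int = 16):
        super().__init__()
        self.gate_proj = nn.Linear(hidden, intermediate, bias=False)
        self.up_proj = nn.Linear(hidden, intermediate, bias=False)
        self.down_proj = nn.Linear(intermediate, hidden, bias=False)
        self.act_fn = nn.SiLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))


def logged_mlp_forward(self, x: torch.Tensor) -> torch.Tensor:
    # Same eager SwiGLU math as the fallback branch of baichuan_mlp_forward.
    print("patched MLP forward called")
    return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))


if __name__ == "__main__":
    model = nn.Sequential(ToyMLP(), ToyMLP())
    patch_forward(model, ToyMLP, logged_mlp_forward)
    out = model(torch.randn(1, 8))  # both MLPs now call the patched forward
    print(out.shape)                # torch.Size([1, 8])
```

The fused path in `baichuan_mlp_forward` only triggers for single-token, fp32, XPU inference (decode steps), which is why the eager SwiGLU expression is kept as the general fallback.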