diff --git a/python/llm/src/bigdl/llm/transformers/convert.py b/python/llm/src/bigdl/llm/transformers/convert.py
index 2f364c1e..7201cc32 100644
--- a/python/llm/src/bigdl/llm/transformers/convert.py
+++ b/python/llm/src/bigdl/llm/transformers/convert.py
@@ -590,6 +590,7 @@ def _optimize_post(model, lightweight_bmm=False):
         modeling_module_name = model.__class__.__module__
         module = importlib.import_module(modeling_module_name)
         from bigdl.llm.transformers.models.qwen import qwen_attention_forward
+        from bigdl.llm.transformers.models.qwen import qwen_mlp_forward
         from bigdl.llm.transformers.models.chatglm2 import chatglm_rms_norm_forward
         convert_forward(model,
                         module.QWenAttention,
@@ -598,6 +599,9 @@ def _optimize_post(model, lightweight_bmm=False):
         convert_forward(model,
                         module.RMSNorm,
                         chatglm_rms_norm_forward)
+        convert_forward(model,
+                        module.QWenMLP,
+                        qwen_mlp_forward)
     elif model.config.model_type == "aquila":
         modeling_module_name = model.__class__.__module__
         module = importlib.import_module(modeling_module_name)
diff --git a/python/llm/src/bigdl/llm/transformers/models/qwen.py b/python/llm/src/bigdl/llm/transformers/models/qwen.py
index 549d137d..18642434 100644
--- a/python/llm/src/bigdl/llm/transformers/models/qwen.py
+++ b/python/llm/src/bigdl/llm/transformers/models/qwen.py
@@ -210,3 +210,17 @@ def qwen_attention_forward(
         outputs += (attn_weight,)
 
     return outputs
+
+
+def qwen_mlp_forward(self, x: torch.Tensor) -> torch.Tensor:
+    if x.shape[1] == 1 and x.dtype == torch.float32 and x.device.type == 'xpu' \
+            and not (self.training and x.requires_grad):
+        import linear_q4_0
+        x_2d = x.view(-1, x.shape[-1])
+        if not x_2d.is_contiguous():
+            x_2d = x_2d.contiguous()
+        return self.c_proj(linear_q4_0.mlp_forward_q4_0_xpu(
+            x_2d, self.w2.weight.data, self.w1.weight.data,
+            x_2d.shape[0], x_2d.shape[1], self.w2.out_len,
+        ))
+    return self.c_proj(F.silu(self.w2(x)) * self.w1(x))
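
Note (not part of the patch): a minimal usage sketch of how the new fused MLP path would be exercised. The patch registers qwen_mlp_forward in place of QWenMLP.forward via convert_forward, and the linear_q4_0.mlp_forward_q4_0_xpu branch only fires during single-token decode steps (x.shape[1] == 1) on an XPU device, outside of training. The model id, prompt, and generation settings below are illustrative assumptions, not part of this change.

    # Sketch only: load a Qwen model with BigDL-LLM 4-bit optimizations so that
    # token-by-token decoding on an Intel GPU ("xpu") hits the fused MLP kernel.
    import torch
    import intel_extension_for_pytorch as ipex  # required for the 'xpu' device
    from transformers import AutoTokenizer
    from bigdl.llm.transformers import AutoModelForCausalLM

    model_id = "Qwen/Qwen-7B-Chat"  # hypothetical choice for this example
    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)

    # load_in_4bit=True converts Linear layers to q4_0; optimize_model=True
    # (the default) runs _optimize_post, which patches QWenMLP.forward with
    # qwen_mlp_forward as in the convert.py hunk above.
    model = AutoModelForCausalLM.from_pretrained(model_id,
                                                 load_in_4bit=True,
                                                 trust_remote_code=True)
    model = model.to('xpu')

    inputs = tokenizer("What is AI?", return_tensors="pt").to('xpu')
    with torch.inference_mode():
        # After prefill, each decode step feeds a single token (x.shape[1] == 1),
        # which is the case the fused xpu MLP path targets.
        output = model.generate(**inputs, max_new_tokens=32)
    print(tokenizer.decode(output[0], skip_special_tokens=True))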