use fuse mlp in qwen (#9672)
parent c7741c4e84
commit 09ca540f9b
2 changed files with 18 additions and 0 deletions
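For context, the fallback branch of the new qwen_mlp_forward below keeps the original, unfused computation: a SwiGLU-style gate (silu(w2(x)) * w1(x)) followed by the output projection c_proj. A minimal standalone sketch of that MLP is shown here; the layer names follow the diff, while the layer sizes are illustrative placeholders rather than Qwen's real configuration.

# Hedged sketch of the unfused Qwen MLP that this commit fuses on XPU.
# Layer names (w1, w2, c_proj) follow the diff; sizes are placeholders.
import torch
import torch.nn as nn
import torch.nn.functional as F


class QwenMLPSketch(nn.Module):
    def __init__(self, hidden_size: int = 4096, intermediate_size: int = 11008):
        super().__init__()
        self.w1 = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.w2 = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.c_proj = nn.Linear(intermediate_size, hidden_size, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Same math as the fallback path in qwen_mlp_forward below:
        # a SwiGLU-style gate followed by the output projection.
        return self.c_proj(F.silu(self.w2(x)) * self.w1(x))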
@@ -590,6 +590,7 @@ def _optimize_post(model, lightweight_bmm=False):
         modeling_module_name = model.__class__.__module__
         module = importlib.import_module(modeling_module_name)
         from bigdl.llm.transformers.models.qwen import qwen_attention_forward
+        from bigdl.llm.transformers.models.qwen import qwen_mlp_forward
         from bigdl.llm.transformers.models.chatglm2 import chatglm_rms_norm_forward
         convert_forward(model,
                         module.QWenAttention,
@@ -598,6 +599,9 @@ def _optimize_post(model, lightweight_bmm=False):
         convert_forward(model,
                         module.RMSNorm,
                         chatglm_rms_norm_forward)
+        convert_forward(model,
+                        module.QWenMLP,
+                        qwen_mlp_forward)
     elif model.config.model_type == "aquila":
         modeling_module_name = model.__class__.__module__
         module = importlib.import_module(modeling_module_name)
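convert_forward is already used here for QWenAttention and RMSNorm; this hunk simply registers qwen_mlp_forward for QWenMLP as well. As a rough, assumed sketch of what such a helper does (not necessarily bigdl.llm's exact implementation): it walks the module tree and rebinds forward on every instance of the target class.

# Assumed, simplified sketch of a convert_forward-style helper; the real
# bigdl.llm.transformers convert_forward may differ in details.
import types
import torch.nn as nn


def convert_forward_sketch(model: nn.Module, target_cls, new_forward) -> None:
    for _, sub_module in model.named_children():
        if isinstance(sub_module, target_cls):
            # Rebind the replacement function as this instance's forward method,
            # so calls like module(x) dispatch to the optimized implementation.
            sub_module.forward = types.MethodType(new_forward, sub_module)
        # Recurse so QWenMLP blocks nested inside decoder layers are patched too.
        convert_forward_sketch(sub_module, target_cls, new_forward)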
@@ -210,3 +210,17 @@ def qwen_attention_forward(
         outputs += (attn_weight,)
 
     return outputs
+
+
+def qwen_mlp_forward(self, x: torch.Tensor) -> torch.Tensor:
+    if x.shape[1] == 1 and x.dtype == torch.float32 and x.device.type == 'xpu' \
+            and not (self.training and x.requires_grad):
+        import linear_q4_0
+        x_2d = x.view(-1, x.shape[-1])
+        if not x_2d.is_contiguous():
+            x_2d = x_2d.contiguous()
+        return self.c_proj(linear_q4_0.mlp_forward_q4_0_xpu(
+            x_2d, self.w2.weight.data, self.w1.weight.data,
+            x_2d.shape[0], x_2d.shape[1], self.w2.out_len,
+        ))
+    return self.c_proj(F.silu(self.w2(x)) * self.w1(x))
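A hedged end-to-end usage sketch: loading a Qwen checkpoint through bigdl-llm's 4-bit AutoModelForCausalLM and moving it to an Intel GPU, so that single-token decode steps (x.shape[1] == 1) hit the fused linear_q4_0.mlp_forward_q4_0_xpu path above. The model id, prompt, and generation settings are placeholders.

# Hedged usage sketch; model id and prompt are placeholders.
import torch
import intel_extension_for_pytorch as ipex  # noqa: F401, registers the 'xpu' device
from transformers import AutoTokenizer
from bigdl.llm.transformers import AutoModelForCausalLM

model_id = "Qwen/Qwen-7B-Chat"
model = AutoModelForCausalLM.from_pretrained(model_id,
                                             load_in_4bit=True,
                                             trust_remote_code=True)
model = model.to("xpu")  # the fused MLP path only triggers on XPU

tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
inputs = tokenizer("What is AI?", return_tensors="pt").to("xpu")

with torch.inference_mode():
    # Each decode step feeds one new token, so x.shape[1] == 1 and the
    # fused q4_0 MLP kernel is used for the patched QWenMLP modules.
    output_ids = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))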