use mlp silu_mul fusion in qwen2 to optimize memory usage (#11574)
This commit is contained in:
parent
13a72dc51d
commit
019da6c0ab
2 changed files with 26 additions and 1 deletions
|
|
@ -1323,6 +1323,7 @@ def _optimize_post(model, lightweight_bmm=False):
|
|||
from ipex_llm.transformers.models.qwen2 import qwen2_model_forward
|
||||
from ipex_llm.transformers.models.qwen2 import qwen2_attention_forward
|
||||
from ipex_llm.transformers.models.qwen2 import qwen2_causal_lm_forward
|
||||
from ipex_llm.transformers.models.qwen2 import qwen2_mlp_forward
|
||||
convert_forward(model,
|
||||
module.Qwen2Model,
|
||||
qwen2_model_forward)
|
||||
|
|
@ -1334,7 +1335,7 @@ def _optimize_post(model, lightweight_bmm=False):
|
|||
llama_rms_norm_forward)
|
||||
convert_forward(model,
|
||||
module.Qwen2MLP,
|
||||
llama_mlp_forward)
|
||||
qwen2_mlp_forward)
|
||||
convert_forward(model,
|
||||
module.Qwen2Attention,
|
||||
qwen2_attention_forward)
|
||||
|
|
|
|||
|
|
@ -45,6 +45,7 @@ import torch
|
|||
from torch.nn import CrossEntropyLoss
|
||||
from torch.nn.functional import scaled_dot_product_attention as sdpa
|
||||
|
||||
from ipex_llm.transformers.models.utils import SILU, mlp_fusion_check
|
||||
from ipex_llm.transformers.models.utils import should_use_fuse_rope
|
||||
from ipex_llm.transformers.models.utils import use_quantize_kv_cache, restore_fp8_kv_cache
|
||||
from ipex_llm.transformers.models.utils import use_flash_attention, use_sdp, use_sdp_causal
|
||||
|
|
@ -491,3 +492,26 @@ def qwen2_attention_forward(
|
|||
if not output_attentions:
|
||||
attn_weights = None
|
||||
return attn_output, attn_weights, past_key_value
|
||||
|
||||
|
||||
def qwen2_mlp_forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
x_2d = x.view(-1, x.shape[-1])
|
||||
qtype = getattr(self.gate_proj, "qtype", None)
|
||||
if mlp_fusion_check(x_2d, qtype, self.training) and not self.down_proj.enable_xetla:
|
||||
import xe_linear
|
||||
return self.down_proj(xe_linear.mlp_forward_xpu(
|
||||
x_2d, self.gate_proj.weight.data, self.up_proj.weight.data,
|
||||
x_2d.shape[0], x_2d.shape[1], self.gate_proj.out_len,
|
||||
SILU, qtype
|
||||
))
|
||||
elif not self.training:
|
||||
import xe_addons
|
||||
gate = self.gate_proj(x)
|
||||
up = self.up_proj(x)
|
||||
xe_addons.mlp_silu_mul_inplaced(gate, up)
|
||||
return self.down_proj(gate)
|
||||
else:
|
||||
return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
|
||||
|
|
|
|||
Loading…
Reference in a new issue