use mlp silu_mul fusion in qwen2 to optimize memory usage (#11574)
parent 13a72dc51d
commit 019da6c0ab

2 changed files with 26 additions and 1 deletion
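This commit adds a dedicated qwen2_mlp_forward and switches module.Qwen2MLP from the generic llama_mlp_forward over to it, so the SiLU activation and elementwise multiply of Qwen2's gated MLP run fused instead of materializing extra intermediate tensors. As a minimal plain-PyTorch sketch of the semantics the fused ops implement (xe_linear.mlp_forward_xpu and xe_addons.mlp_silu_mul_inplaced in the diff are IPEX-LLM's XPU extension kernels; the reference below only illustrates the math, not the kernels):

    import torch
    import torch.nn.functional as F

    def silu_mul_reference(gate: torch.Tensor, up: torch.Tensor) -> torch.Tensor:
        # Unfused path: materializes silu(gate) as an extra activation-sized tensor.
        return F.silu(gate) * up

    def silu_mul_inplace(gate: torch.Tensor, up: torch.Tensor) -> torch.Tensor:
        # Mirrors what xe_addons.mlp_silu_mul_inplaced computes:
        # gate <- silu(gate) * up, reusing gate's buffer for the result.
        # torch.sigmoid still allocates a temporary here; the real kernel
        # fuses the whole pass, which is where the memory saving comes from.
        return gate.mul_(torch.sigmoid(gate)).mul_(up)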
@@ -1323,6 +1323,7 @@ def _optimize_post(model, lightweight_bmm=False):
         from ipex_llm.transformers.models.qwen2 import qwen2_model_forward
         from ipex_llm.transformers.models.qwen2 import qwen2_attention_forward
         from ipex_llm.transformers.models.qwen2 import qwen2_causal_lm_forward
+        from ipex_llm.transformers.models.qwen2 import qwen2_mlp_forward
         convert_forward(model,
                         module.Qwen2Model,
                         qwen2_model_forward)
@@ -1334,7 +1335,7 @@ def _optimize_post(model, lightweight_bmm=False):
                         llama_rms_norm_forward)
         convert_forward(model,
                         module.Qwen2MLP,
-                        llama_mlp_forward)
+                        qwen2_mlp_forward)
         convert_forward(model,
                         module.Qwen2Attention,
                         qwen2_attention_forward)
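Both hunks above touch the conversion module that defines _optimize_post: the first imports the new forward, the second rebinds module.Qwen2MLP from llama_mlp_forward to qwen2_mlp_forward. A hedged sketch of the rebinding pattern convert_forward performs (the actual ipex_llm implementation may differ in detail):

    import types
    import torch.nn as nn

    def convert_forward_sketch(model: nn.Module, target_cls: type, new_forward) -> None:
        # Walk the module tree and bind new_forward as the forward method of
        # every instance of target_cls (e.g. module.Qwen2MLP).
        for sub_module in model.modules():
            if isinstance(sub_module, target_cls):
                sub_module.forward = types.MethodType(new_forward, sub_module)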
@@ -45,6 +45,7 @@ import torch
 from torch.nn import CrossEntropyLoss
 from torch.nn.functional import scaled_dot_product_attention as sdpa
 
+from ipex_llm.transformers.models.utils import SILU, mlp_fusion_check
 from ipex_llm.transformers.models.utils import should_use_fuse_rope
 from ipex_llm.transformers.models.utils import use_quantize_kv_cache, restore_fp8_kv_cache
 from ipex_llm.transformers.models.utils import use_flash_attention, use_sdp, use_sdp_causal
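The remaining hunks are in ipex_llm/transformers/models/qwen2.py. The added import brings in SILU, an activation-type flag forwarded to the fused kernel, and mlp_fusion_check, which gates the fused path on the flattened input, the weight quantization type (qtype), and training mode, as the new function below shows.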
@@ -491,3 +492,26 @@ def qwen2_attention_forward(
     if not output_attentions:
         attn_weights = None
     return attn_output, attn_weights, past_key_value
+
+
+def qwen2_mlp_forward(
+    self,
+    x: torch.Tensor,
+) -> torch.Tensor:
+    x_2d = x.view(-1, x.shape[-1])
+    qtype = getattr(self.gate_proj, "qtype", None)
+    if mlp_fusion_check(x_2d, qtype, self.training) and not self.down_proj.enable_xetla:
+        import xe_linear
+        return self.down_proj(xe_linear.mlp_forward_xpu(
+            x_2d, self.gate_proj.weight.data, self.up_proj.weight.data,
+            x_2d.shape[0], x_2d.shape[1], self.gate_proj.out_len,
+            SILU, qtype
+        ))
+    elif not self.training:
+        import xe_addons
+        gate = self.gate_proj(x)
+        up = self.up_proj(x)
+        xe_addons.mlp_silu_mul_inplaced(gate, up)
+        return self.down_proj(gate)
+    else:
+        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))