use fuse mlp in qwen (#9672)
This commit is contained in:
		
							parent
							
								
									c7741c4e84
								
							
						
					
					
						commit
						09ca540f9b
					
				
					 2 changed files with 18 additions and 0 deletions
				
			
		| 
						 | 
					@ -590,6 +590,7 @@ def _optimize_post(model, lightweight_bmm=False):
 | 
				
			||||||
            modeling_module_name = model.__class__.__module__
 | 
					            modeling_module_name = model.__class__.__module__
 | 
				
			||||||
            module = importlib.import_module(modeling_module_name)
 | 
					            module = importlib.import_module(modeling_module_name)
 | 
				
			||||||
            from bigdl.llm.transformers.models.qwen import qwen_attention_forward
 | 
					            from bigdl.llm.transformers.models.qwen import qwen_attention_forward
 | 
				
			||||||
 | 
					            from bigdl.llm.transformers.models.qwen import qwen_mlp_forward
 | 
				
			||||||
            from bigdl.llm.transformers.models.chatglm2 import chatglm_rms_norm_forward
 | 
					            from bigdl.llm.transformers.models.chatglm2 import chatglm_rms_norm_forward
 | 
				
			||||||
            convert_forward(model,
 | 
					            convert_forward(model,
 | 
				
			||||||
                            module.QWenAttention,
 | 
					                            module.QWenAttention,
 | 
				
			||||||
| 
						 | 
					@ -598,6 +599,9 @@ def _optimize_post(model, lightweight_bmm=False):
 | 
				
			||||||
            convert_forward(model,
 | 
					            convert_forward(model,
 | 
				
			||||||
                            module.RMSNorm,
 | 
					                            module.RMSNorm,
 | 
				
			||||||
                            chatglm_rms_norm_forward)
 | 
					                            chatglm_rms_norm_forward)
 | 
				
			||||||
 | 
					            convert_forward(model,
 | 
				
			||||||
 | 
					                            module.QWenMLP,
 | 
				
			||||||
 | 
					                            qwen_mlp_forward)
 | 
				
			||||||
    elif model.config.model_type == "aquila":
 | 
					    elif model.config.model_type == "aquila":
 | 
				
			||||||
        modeling_module_name = model.__class__.__module__
 | 
					        modeling_module_name = model.__class__.__module__
 | 
				
			||||||
        module = importlib.import_module(modeling_module_name)
 | 
					        module = importlib.import_module(modeling_module_name)
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -210,3 +210,17 @@ def qwen_attention_forward(
 | 
				
			||||||
            outputs += (attn_weight,)
 | 
					            outputs += (attn_weight,)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    return outputs
 | 
					    return outputs
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def qwen_mlp_forward(self, x: torch.Tensor) -> torch.Tensor:
 | 
				
			||||||
 | 
					    if x.shape[1] == 1 and x.dtype == torch.float32 and x.device.type == 'xpu' \
 | 
				
			||||||
 | 
					            and not (self.training and x.requires_grad):
 | 
				
			||||||
 | 
					        import linear_q4_0
 | 
				
			||||||
 | 
					        x_2d = x.view(-1, x.shape[-1])
 | 
				
			||||||
 | 
					        if not x_2d.is_contiguous():
 | 
				
			||||||
 | 
					            x_2d = x_2d.contiguous()
 | 
				
			||||||
 | 
					        return self.c_proj(linear_q4_0.mlp_forward_q4_0_xpu(
 | 
				
			||||||
 | 
					            x_2d, self.w2.weight.data, self.w1.weight.data,
 | 
				
			||||||
 | 
					            x_2d.shape[0], x_2d.shape[1], self.w2.out_len,
 | 
				
			||||||
 | 
					        ))
 | 
				
			||||||
 | 
					    return self.c_proj(F.silu(self.w2(x)) * self.w1(x))
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in a new issue