small fix (#12493)
This commit is contained in:
		
							parent
							
								
									7ff4533b39
								
							
						
					
					
						commit
						e0bf0054e1
					
				
					 2 changed files with 5 additions and 7 deletions
				
			
		| 
						 | 
				
			
			@ -116,18 +116,18 @@ def padding_qkv_hd(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
 | 
			
		|||
 | 
			
		||||
def fuse_mlp_base(module: torch.nn.Module, act: int, x: torch.Tensor):
 | 
			
		||||
    from ipex_llm.transformers.models.utils import mlp_fusion_check
 | 
			
		||||
    x_2d = x.view(-1, x.size(-1))
 | 
			
		||||
    qtype = getattr(module.gate_proj, "qtype", None)
 | 
			
		||||
    if mlp_fusion_check(x_2d, qtype, module.training):
 | 
			
		||||
    if mlp_fusion_check(x, qtype, module.training):
 | 
			
		||||
        import xe_linear
 | 
			
		||||
        x_2d = x_2d.contiguous()
 | 
			
		||||
        return module.down_proj(
 | 
			
		||||
        x_2d = x.contiguous().view(-1, x.size(-1))
 | 
			
		||||
        output = module.down_proj(
 | 
			
		||||
            xe_linear.mlp_forward_xpu(
 | 
			
		||||
                x_2d, module.gate_proj.weight.data, module.up_proj.weight.data,
 | 
			
		||||
                x_2d.size(0), x_2d.size(1), module.gate_proj.out_len,
 | 
			
		||||
                act, qtype
 | 
			
		||||
            )
 | 
			
		||||
        )
 | 
			
		||||
        return output.view(x.shape)
 | 
			
		||||
    else:
 | 
			
		||||
        return module.down_proj(module.act_fn(module.gate_proj(x)) * module.up_proj(x))
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -336,9 +336,7 @@ def use_sdp_non_causal(head_dim, device, dtype):
 | 
			
		|||
 | 
			
		||||
 | 
			
		||||
def mlp_fusion_check(x, qtype, training):
 | 
			
		||||
    invalidInputError(x.dim() == 2,
 | 
			
		||||
                      "Here input x's dim should be 2.")
 | 
			
		||||
    if x.shape[0] != 1:
 | 
			
		||||
    if x.numel() // x.size(-1) != 1:
 | 
			
		||||
        return False
 | 
			
		||||
    if x.device.type != 'xpu':
 | 
			
		||||
        return False
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in a new issue