remove unnecessary ipex kernel usage (#12649)
parent 9f8b134889
commit 502461d836

2 changed files with 5 additions and 60 deletions
@@ -1984,16 +1984,9 @@ def _optimize_post(model):
         modeling_module_name = model.__class__.__module__
         module = importlib.import_module(modeling_module_name)
         from ipex_llm.transformers.models.yuan import yuan_attention_forward
-        # from ipex_llm.transformers.models.yuan import yuan_mlp_forward
-        convert_forward(model,
-                        module.YuanAttention,
-                        yuan_attention_forward
-                        )
-        # disable able mlp_forward for quantize_kv on mtl.
-        # convert_forward(model,
-        #                 module.YuanMLP,
-        #                 yuan_mlp_forward
-        #                 )
+        convert_forward(model, module.YuanAttention, yuan_attention_forward)
+        # from ipex_llm.transformers.models.common import mlp_silu_forward
+        # convert_forward(model, module.YuanMLP, mlp_silu_forward)
     elif model.config.model_type == 'bert' and (
         not model.config.is_decoder and
         model.config.position_embedding_type == "absolute"
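For context, the collapsed one-line call above registers the same yuan_attention_forward patch as before; convert_forward is ipex-llm's helper for swapping the forward method of every matching module in the model. A minimal sketch of that patching idea, under the assumption that it simply rebinds forward on each instance of the target class (the helper's real implementation may differ):

# Minimal sketch of the module-patching idea behind convert_forward.
# Assumption: the real helper in ipex_llm may differ; this only illustrates
# rebinding a replacement forward onto every module of the target class.
import types
import torch.nn as nn

def convert_forward_sketch(model: nn.Module, target_cls: type, new_forward) -> None:
    for module in model.modules():
        if isinstance(module, target_cls):
            # Bind new_forward as a method so `self` resolves to this module.
            module.forward = types.MethodType(new_forward, module)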
@@ -20,17 +20,15 @@
 # https://huggingface.co/IEITYuan/Yuan2-2B-hf/blob/7ab7b3c18eb8e5232ce2a3f720d4e6f4b53a2806/README.md#%E5%A3%B0%E6%98%8E%E4%B8%8E%E5%8D%8F%E8%AE%AEterms-and-conditions
 #

 import math
 from typing import Optional, Tuple

 import torch

 from ipex_llm.utils.common import invalidInputError
 from ipex_llm.transformers.models.common import scaled_dot_product_attention
-from ipex_llm.transformers.models.utils import apply_rotary_pos_emb, \
-    mlp_fusion_check, fp16_fusion_check
+from ipex_llm.transformers.models.utils import apply_rotary_pos_emb
 from ipex_llm.transformers.models.utils import use_quantize_kv_cache
-from ipex_llm.transformers.models.utils import SILU, update_past_key_value
+from ipex_llm.transformers.models.utils import update_past_key_value
 from ipex_llm.transformers.models.utils import should_use_fuse_rope
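The surviving import of scaled_dot_product_attention from ipex_llm.transformers.models.common fits the theme of the commit: rely on a common attention path rather than model-specific ipex kernels. That wrapper's exact signature is not shown in this diff; the hedged sketch below only illustrates the stock PyTorch SDPA call such a helper would typically build on.

# Hedged sketch: a plain attention call using PyTorch's built-in SDPA.
# Assumption: the common helper in ipex_llm wraps something like this;
# its actual signature and dispatch logic are not part of this diff.
from typing import Optional
import torch
import torch.nn.functional as F

def sdpa_sketch(query: torch.Tensor,
                key: torch.Tensor,
                value: torch.Tensor,
                attn_mask: Optional[torch.Tensor] = None,
                is_causal: bool = False) -> torch.Tensor:
    # query/key/value are expected as [batch, num_heads, seq_len, head_dim].
    return F.scaled_dot_product_attention(
        query, key, value, attn_mask=attn_mask, is_causal=is_causal
    )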
@@ -98,52 +96,6 @@ def yuan_localized_filtering_forward(
     return lf_output


-def yuan_mlp_forward(
-    self,
-    x: torch.Tensor,
-    residual=None
-) -> torch.Tensor:
-    x_2d = x.view(-1, x.shape[-1])
-    bsz, hidden_size = x_2d.shape
-    qtype = getattr(self.up_proj, "qtype", None)
-    if mlp_fusion_check(x_2d, qtype, self.training):
-        import xe_linear
-        if not x_2d.is_contiguous():
-            x_2d = x_2d.contiguous()
-        out = self.down_proj(xe_linear.mlp_forward_xpu(
-            x_2d, self.up_proj.weight.data, self.gate_proj.weight.data,
-            x_2d.shape[0], x_2d.shape[1], self.up_proj.out_len,
-            SILU, qtype
-        ))
-        if residual is not None:
-            return out + residual
-        else:
-            return out
-    elif fp16_fusion_check(self.up_proj, x, self.training) and \
-            hidden_size == 4096 and bsz == 1:
-        hidden_states1 = torch.ops.torch_ipex.mm_silu(x, self.up_proj.weight)
-        hidden_states = torch.ops.torch_ipex.mm_resmul(
-            x, self.gate_proj.weight, hidden_states1
-        )
-        if residual is None:
-            hidden_states = torch.matmul(hidden_states, self.down_proj.weight)
-        else:
-            attn_output = torch.addmm(
-                residual.flatten(0, -2),
-                hidden_states.flatten(0, -2),
-                self.down_proj.weight,
-                beta=1,
-            )
-            hidden_states = attn_output.view(x.shape)
-        return hidden_states
-    else:
-        out = self.down_proj(self.act_fn(self.up_proj(x)) * self.gate_proj(x))
-        if residual is not None:
-            return out + residual
-        else:
-            return out
-
-
 def yuan_attention_forward(
     self,
     hidden_states: torch.Tensor,
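The deleted yuan_mlp_forward carried two fused paths (xe_linear.mlp_forward_xpu and the torch.ops.torch_ipex matmuls) around a plain eager fallback; all three branches compute the same gated MLP, which is why the function can be dropped. A condensed sketch of that common computation, mirroring the fallback branch above (the projection and activation attribute names are taken from the removed code and assumed to match YuanMLP):

# Condensed sketch of the computation the removed fused branches implement.
# Attribute names (up_proj, gate_proj, down_proj, act_fn) follow the deleted
# fallback branch and are assumed to match the original YuanMLP module.
from typing import Optional
import torch

def yuan_mlp_eager(self, x: torch.Tensor,
                   residual: Optional[torch.Tensor] = None) -> torch.Tensor:
    # Note the Yuan-specific ordering: the activation is applied to up_proj(x)
    # and gated by gate_proj(x), exactly as in the removed eager fallback.
    out = self.down_proj(self.act_fn(self.up_proj(x)) * self.gate_proj(x))
    return out + residual if residual is not None else out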