refactor to simplify following upgrade (#12680)
parent aa9e70a347
commit 1ec40cd09e

5 changed files with 10 additions and 87 deletions

@@ -1325,7 +1325,6 @@ def _optimize_post(model):
             modeling_module_name = model.__class__.__module__
             module = importlib.import_module(modeling_module_name)
             from ipex_llm.transformers.models.chatglm2 import chatglm2_attention_forward
-            from ipex_llm.transformers.models.chatglm2 import chatglm_rms_norm_forward
             from ipex_llm.transformers.models.chatglm2 import chatglm2_encoder_forward
             from ipex_llm.transformers.models.chatglm2 import chatglm2_model_forward
             from ipex_llm.transformers.models.chatglm2 import mlp_forward
@@ -1338,9 +1337,7 @@ def _optimize_post(model):
             convert_forward(model,
                             module.ChatGLMModel,
                             chatglm2_model_forward)
-            convert_forward(model,
-                            module.RMSNorm,
-                            chatglm_rms_norm_forward)
+            convert_forward(model, module.RMSNorm, rms_norm_forward)
             convert_forward(model, module.MLP, mlp_forward)
             # for codegeex-nano
             if hasattr(model.config, "rope_ratio"):
@@ -1358,8 +1355,7 @@ def _optimize_post(model):
             # glm4 family
             modeling_module_name = model.__class__.__module__
             module = importlib.import_module(modeling_module_name)
-            from ipex_llm.transformers.models.chatglm2 import chatglm_rms_norm_forward
-            convert_forward(model, module.RMSNorm, chatglm_rms_norm_forward)
+            convert_forward(model, module.RMSNorm, rms_norm_forward)
 
             if hasattr(model.transformer, "vision"):
                 # glm4 vision family
@@ -1448,8 +1444,8 @@ def _optimize_post(model):
     elif model.config.model_type == "baichuan":
         modeling_module_name = model.__class__.__module__
         module = importlib.import_module(modeling_module_name)
-        from ipex_llm.transformers.models.baichuan import baichuan_mlp_forward
-        convert_forward(model, module.MLP, baichuan_mlp_forward)
+        convert_forward(model, module.RMSNorm, rms_norm_forward)
+        convert_forward(model, module.MLP, mlp_silu_forward)
 
         if model.config.hidden_size in [4096, 2048]:
             # baichuan-7B and baichuan2-7B
@@ -1458,7 +1454,6 @@ def _optimize_post(model):
             for i in range(len(model.model.layers)):
                 setattr(model.model.layers[i].self_attn, "layer_idx", i)
             convert_forward(model, module.Attention, baichuan_attention_forward_7b)
-            convert_forward(model, module.RMSNorm, rms_norm_forward)
             if model.config.vocab_size == 125696:
                 # baichuan2-7B
                 convert_forward(model, module.BaichuanModel, baichuan_model_7b_forward)
@@ -1468,9 +1463,7 @@ def _optimize_post(model):
         elif model.config.hidden_size == 5120:
             # baichuan-13B and baichuan2-13B
             from ipex_llm.transformers.models.baichuan import baichuan_attention_forward_13b
-            from ipex_llm.transformers.models.baichuan import baichuan_13b_rms_norm_forward
             convert_forward(model, module.BaichuanAttention, baichuan_attention_forward_13b)
-            convert_forward(model, module.RMSNorm, baichuan_13b_rms_norm_forward)
 
             if model.config.vocab_size == 125696:
                 # baichaun2-13B
@@ -1565,7 +1558,6 @@ def _optimize_post(model):
             from ipex_llm.transformers.models.qwen import qwen_attention_forward
             from ipex_llm.transformers.models.qwen import qwen_attention_forward_registered
             from ipex_llm.transformers.models.qwen import qwen_mlp_forward
-            from ipex_llm.transformers.models.chatglm2 import chatglm_rms_norm_forward
             from ipex_llm.transformers.models.qwen import qwen_model_forward
             if model.config.max_position_embeddings == 8192 \
                and model.config.hidden_size == 4096:
@@ -1580,7 +1572,7 @@ def _optimize_post(model):
                                 )
             convert_forward(model,
                             module.RMSNorm,
-                            chatglm_rms_norm_forward)
+                            rms_norm_forward)
             convert_forward(model,
                             module.QWenMLP,
                             qwen_mlp_forward)
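
Across these hunks the per-model RMSNorm overrides (chatglm_rms_norm_forward and friends) are dropped in favor of the shared rms_norm_forward, and the Baichuan MLP override is replaced by the generic mlp_silu_forward, all registered through convert_forward. convert_forward itself is not part of this diff; the following is only a minimal sketch, under the assumption that it simply rebinds forward on every matching submodule (the real ipex_llm helper may do more):

    import types
    import torch

    def convert_forward_sketch(model: torch.nn.Module, target_cls: type, new_forward) -> None:
        # Hypothetical sketch of a convert_forward-style helper: walk the module
        # tree and rebind `forward` on every instance of `target_cls`.
        for submodule in model.modules():
            if isinstance(submodule, target_cls):
                submodule.forward = types.MethodType(new_forward, submodule)

    # Usage mirroring the hunks above, e.g.:
    # convert_forward_sketch(model, module.RMSNorm, rms_norm_forward)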

@@ -47,38 +47,6 @@ def pre_compute_inv_freq(module: torch.nn.Module):
         module.register_buffer("inv_freq", inv_freq, persistent=False)
 
 
-def baichuan_13b_rms_norm_forward(self, hidden_states):
-    if hidden_states.device.type == "xpu" and not (self.training or hidden_states.requires_grad):
-        import xe_addons
-        x_2d = hidden_states.reshape(-1, hidden_states.size(-1)).contiguous()
-        output = xe_addons.rms_norm(self.weight, x_2d, self.epsilon)
-        return output.reshape(hidden_states.shape)
-
-    input_dtype = hidden_states.dtype
-    hidden_states = hidden_states.to(torch.float32)
-    variance = hidden_states.pow(2).mean(-1, keepdim=True)
-    hidden_states = hidden_states * torch.rsqrt(variance + self.epsilon)
-    return self.weight * hidden_states.to(input_dtype)
-
-
-def baichuan_mlp_forward(
-    self,
-    x: torch.Tensor,
-) -> torch.Tensor:
-    x_2d = x.view(-1, x.shape[-1])
-    qtype = getattr(self.gate_proj, "qtype", None)
-    if mlp_fusion_check(x_2d, qtype, self.training):
-        import xe_linear
-        if not x_2d.is_contiguous():
-            x_2d = x_2d.contiguous()
-        return self.down_proj(xe_linear.mlp_forward_xpu(
-            x_2d, self.gate_proj.weight.data, self.up_proj.weight.data,
-            x_2d.shape[0], x_2d.shape[1], self.gate_proj.out_len,
-            SILU, qtype
-        ))
-    return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
-
-
 def baichuan_model_7b_forward(
         self,
         input_ids: torch.LongTensor = None,
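
The deleted Baichuan-specific forwards are superseded by the shared helpers that the earlier hunks now register: rms_norm_forward for module.RMSNorm and mlp_silu_forward for module.MLP. mlp_silu_forward is not shown in this diff; a minimal sketch of the SiLU-gated MLP it presumably implements, mirroring the non-fused fallback branch of the deleted baichuan_mlp_forward, is:

    import torch
    import torch.nn.functional as F

    def mlp_silu_forward_sketch(self, x: torch.Tensor) -> torch.Tensor:
        # Hedged sketch: standard gate/up/down SiLU MLP; the real helper may also
        # dispatch to fused XPU kernels, as the deleted code did via xe_linear.
        return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))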

@@ -36,24 +36,13 @@ import math
 import torch
 from typing import Optional, Tuple
 from transformers.models.bert.modeling_bert import BertSelfAttention, BertEncoder
+from ipex_llm.transformers.models.common import merge_linear
 from ipex_llm.utils.common import invalidInputError
 
 
 def merge_qkv(module: torch.nn.Module):
     if isinstance(module, BertSelfAttention):
-        q_w = module.query.weight.data
-        k_w = module.key.weight.data
-        v_w = module.value.weight.data
-        q_b = module.query.bias.data
-        k_b = module.key.bias.data
-        v_b = module.value.bias.data
-        new_w = torch.cat([q_w, k_w, v_w], dim=0)
-        new_b = torch.cat([q_b, k_b, v_b], dim=-1)
-        qkv = torch.nn.Linear(0, 0, bias=True)
-        qkv.weight = torch.nn.Parameter(new_w, requires_grad=False)
-        qkv.bias = torch.nn.Parameter(new_b, requires_grad=False)
-        qkv.in_features = module.query.in_features
-        qkv.out_features = module.query.out_features * 3
+        qkv = merge_linear([module.query, module.key, module.value])
         module.qkv = qkv
         del module.query
         del module.key
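
The hand-rolled q/k/v concatenation is replaced by a single call to merge_linear from ipex_llm.transformers.models.common. Its exact implementation is outside this diff; a minimal sketch of what it presumably does, based on the code it replaces, is:

    from typing import List
    import torch

    def merge_linear_sketch(linears: List[torch.nn.Linear]) -> torch.nn.Linear:
        # Hedged sketch: fuse several Linear layers (e.g. query/key/value) along
        # the output dimension, as the deleted code did by hand.
        new_weight = torch.cat([layer.weight.data for layer in linears], dim=0)
        has_bias = linears[0].bias is not None
        merged = torch.nn.Linear(linears[0].in_features,
                                 sum(layer.out_features for layer in linears),
                                 bias=has_bias)
        merged.weight = torch.nn.Parameter(new_weight, requires_grad=False)
        if has_bias:
            new_bias = torch.cat([layer.bias.data for layer in linears], dim=-1)
            merged.bias = torch.nn.Parameter(new_bias, requires_grad=False)
        return merged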

@@ -33,34 +33,6 @@ from ipex_llm.transformers.kv import DynamicCompressCache, DynamicCompressFp8Cac
 KV_CACHE_ALLOC_BLOCK_LENGTH = int(os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256))
 
 
-def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
-    """
-    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states
-    go from (batch, num_key_value_heads, seqlen, head_dim) to
-    (batch, num_attention_heads, seqlen, head_dim)
-    """
-    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
-    if n_rep == 1:
-        return hidden_states
-    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads,
-                                                           n_rep, slen, head_dim)
-    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
-
-
-def chatglm_rms_norm_forward(self, hidden_states):
-    if hidden_states.device.type == "xpu" and not (self.training and hidden_states.requires_grad):
-        import xe_addons
-        x_2d = hidden_states.reshape(-1, hidden_states.size(-1)).contiguous()
-        output = xe_addons.rms_norm(self.weight, x_2d, self.eps)
-        return output.reshape(hidden_states.shape)
-
-    input_dtype = hidden_states.dtype
-    hidden_states = hidden_states.to(torch.float32)
-    variance = hidden_states.pow(2).mean(-1, keepdim=True)
-    hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
-    return self.weight * hidden_states.to(input_dtype)
-
-
 def chatglm2_model_forward(
     self,
     input_ids,

@@ -157,8 +157,10 @@ def rms_norm_forward(self, hidden_states: torch.Tensor):
     weight = self.weight
     if hasattr(self, "variance_epsilon"):
         eps = self.variance_epsilon
-    else:
+    elif hasattr(self, "epsilon"):
+        eps = self.epsilon
+    else:
         eps = self.eps
 
     if hidden_states.device.type == 'xpu' and hidden_states.dtype in [torch.float, torch.half]:
         import xe_addons
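
The new elif branch lets the shared rms_norm_forward serve modules that store their epsilon under the attribute name "epsilon" (as the deleted Baichuan-13B RMSNorm does), in addition to "variance_epsilon" and "eps". A hedged usage sketch, with a hypothetical RMSNorm class standing in for a real model module:

    import torch

    class ToyRMSNorm(torch.nn.Module):
        # Hypothetical stand-in for a model RMSNorm whose epsilon is named "epsilon".
        def __init__(self, hidden_size: int, epsilon: float = 1e-6):
            super().__init__()
            self.weight = torch.nn.Parameter(torch.ones(hidden_size))
            self.epsilon = epsilon

    # After convert_forward(model, ToyRMSNorm, rms_norm_forward), eps is resolved
    # through the new hasattr(self, "epsilon") branch rather than falling through
    # to the missing "eps" attribute.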