refactor to simplify following upgrade (#12680)

This commit is contained in:
Yishuo Wang 2025-01-09 13:34:30 +08:00 committed by GitHub
parent aa9e70a347
commit 1ec40cd09e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 10 additions and 87 deletions

View file

@ -1325,7 +1325,6 @@ def _optimize_post(model):
modeling_module_name = model.__class__.__module__
module = importlib.import_module(modeling_module_name)
from ipex_llm.transformers.models.chatglm2 import chatglm2_attention_forward
from ipex_llm.transformers.models.chatglm2 import chatglm_rms_norm_forward
from ipex_llm.transformers.models.chatglm2 import chatglm2_encoder_forward
from ipex_llm.transformers.models.chatglm2 import chatglm2_model_forward
from ipex_llm.transformers.models.chatglm2 import mlp_forward
@ -1338,9 +1337,7 @@ def _optimize_post(model):
convert_forward(model,
module.ChatGLMModel,
chatglm2_model_forward)
convert_forward(model,
module.RMSNorm,
chatglm_rms_norm_forward)
convert_forward(model, module.RMSNorm, rms_norm_forward)
convert_forward(model, module.MLP, mlp_forward)
# for codegeex-nano
if hasattr(model.config, "rope_ratio"):
@ -1358,8 +1355,7 @@ def _optimize_post(model):
# glm4 family
modeling_module_name = model.__class__.__module__
module = importlib.import_module(modeling_module_name)
from ipex_llm.transformers.models.chatglm2 import chatglm_rms_norm_forward
convert_forward(model, module.RMSNorm, chatglm_rms_norm_forward)
convert_forward(model, module.RMSNorm, rms_norm_forward)
if hasattr(model.transformer, "vision"):
# glm4 vision family
@ -1448,8 +1444,8 @@ def _optimize_post(model):
elif model.config.model_type == "baichuan":
modeling_module_name = model.__class__.__module__
module = importlib.import_module(modeling_module_name)
from ipex_llm.transformers.models.baichuan import baichuan_mlp_forward
convert_forward(model, module.MLP, baichuan_mlp_forward)
convert_forward(model, module.RMSNorm, rms_norm_forward)
convert_forward(model, module.MLP, mlp_silu_forward)
if model.config.hidden_size in [4096, 2048]:
# baichuan-7B and baichuan2-7B
@ -1458,7 +1454,6 @@ def _optimize_post(model):
for i in range(len(model.model.layers)):
setattr(model.model.layers[i].self_attn, "layer_idx", i)
convert_forward(model, module.Attention, baichuan_attention_forward_7b)
convert_forward(model, module.RMSNorm, rms_norm_forward)
if model.config.vocab_size == 125696:
# baichuan2-7B
convert_forward(model, module.BaichuanModel, baichuan_model_7b_forward)
@ -1468,9 +1463,7 @@ def _optimize_post(model):
elif model.config.hidden_size == 5120:
# baichuan-13B and baichuan2-13B
from ipex_llm.transformers.models.baichuan import baichuan_attention_forward_13b
from ipex_llm.transformers.models.baichuan import baichuan_13b_rms_norm_forward
convert_forward(model, module.BaichuanAttention, baichuan_attention_forward_13b)
convert_forward(model, module.RMSNorm, baichuan_13b_rms_norm_forward)
if model.config.vocab_size == 125696:
# baichaun2-13B
@ -1565,7 +1558,6 @@ def _optimize_post(model):
from ipex_llm.transformers.models.qwen import qwen_attention_forward
from ipex_llm.transformers.models.qwen import qwen_attention_forward_registered
from ipex_llm.transformers.models.qwen import qwen_mlp_forward
from ipex_llm.transformers.models.chatglm2 import chatglm_rms_norm_forward
from ipex_llm.transformers.models.qwen import qwen_model_forward
if model.config.max_position_embeddings == 8192 \
and model.config.hidden_size == 4096:
@ -1580,7 +1572,7 @@ def _optimize_post(model):
)
convert_forward(model,
module.RMSNorm,
chatglm_rms_norm_forward)
rms_norm_forward)
convert_forward(model,
module.QWenMLP,
qwen_mlp_forward)

View file

@ -47,38 +47,6 @@ def pre_compute_inv_freq(module: torch.nn.Module):
module.register_buffer("inv_freq", inv_freq, persistent=False)
def baichuan_13b_rms_norm_forward(self, hidden_states):
if hidden_states.device.type == "xpu" and not (self.training or hidden_states.requires_grad):
import xe_addons
x_2d = hidden_states.reshape(-1, hidden_states.size(-1)).contiguous()
output = xe_addons.rms_norm(self.weight, x_2d, self.epsilon)
return output.reshape(hidden_states.shape)
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.epsilon)
return self.weight * hidden_states.to(input_dtype)
def baichuan_mlp_forward(
self,
x: torch.Tensor,
) -> torch.Tensor:
x_2d = x.view(-1, x.shape[-1])
qtype = getattr(self.gate_proj, "qtype", None)
if mlp_fusion_check(x_2d, qtype, self.training):
import xe_linear
if not x_2d.is_contiguous():
x_2d = x_2d.contiguous()
return self.down_proj(xe_linear.mlp_forward_xpu(
x_2d, self.gate_proj.weight.data, self.up_proj.weight.data,
x_2d.shape[0], x_2d.shape[1], self.gate_proj.out_len,
SILU, qtype
))
return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
def baichuan_model_7b_forward(
self,
input_ids: torch.LongTensor = None,

View file

@ -36,24 +36,13 @@ import math
import torch
from typing import Optional, Tuple
from transformers.models.bert.modeling_bert import BertSelfAttention, BertEncoder
from ipex_llm.transformers.models.common import merge_linear
from ipex_llm.utils.common import invalidInputError
def merge_qkv(module: torch.nn.Module):
if isinstance(module, BertSelfAttention):
q_w = module.query.weight.data
k_w = module.key.weight.data
v_w = module.value.weight.data
q_b = module.query.bias.data
k_b = module.key.bias.data
v_b = module.value.bias.data
new_w = torch.cat([q_w, k_w, v_w], dim=0)
new_b = torch.cat([q_b, k_b, v_b], dim=-1)
qkv = torch.nn.Linear(0, 0, bias=True)
qkv.weight = torch.nn.Parameter(new_w, requires_grad=False)
qkv.bias = torch.nn.Parameter(new_b, requires_grad=False)
qkv.in_features = module.query.in_features
qkv.out_features = module.query.out_features * 3
qkv = merge_linear([module.query, module.key, module.value])
module.qkv = qkv
del module.query
del module.key

View file

@ -33,34 +33,6 @@ from ipex_llm.transformers.kv import DynamicCompressCache, DynamicCompressFp8Cac
KV_CACHE_ALLOC_BLOCK_LENGTH = int(os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256))
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
"""
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states
go from (batch, num_key_value_heads, seqlen, head_dim) to
(batch, num_attention_heads, seqlen, head_dim)
"""
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
if n_rep == 1:
return hidden_states
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads,
n_rep, slen, head_dim)
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
def chatglm_rms_norm_forward(self, hidden_states):
if hidden_states.device.type == "xpu" and not (self.training and hidden_states.requires_grad):
import xe_addons
x_2d = hidden_states.reshape(-1, hidden_states.size(-1)).contiguous()
output = xe_addons.rms_norm(self.weight, x_2d, self.eps)
return output.reshape(hidden_states.shape)
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
return self.weight * hidden_states.to(input_dtype)
def chatglm2_model_forward(
self,
input_ids,

View file

@ -157,8 +157,10 @@ def rms_norm_forward(self, hidden_states: torch.Tensor):
weight = self.weight
if hasattr(self, "variance_epsilon"):
eps = self.variance_epsilon
else:
elif hasattr(self, "epsilon"):
eps = self.epsilon
else:
eps = self.eps
if hidden_states.device.type == 'xpu' and hidden_states.dtype in [torch.float, torch.half]:
import xe_addons