refactor to simplify following upgrade (#12680)

parent aa9e70a347
commit 1ec40cd09e

5 changed files with 10 additions and 87 deletions
@@ -1325,7 +1325,6 @@ def _optimize_post(model):
         modeling_module_name = model.__class__.__module__
         module = importlib.import_module(modeling_module_name)
         from ipex_llm.transformers.models.chatglm2 import chatglm2_attention_forward
-        from ipex_llm.transformers.models.chatglm2 import chatglm_rms_norm_forward
         from ipex_llm.transformers.models.chatglm2 import chatglm2_encoder_forward
         from ipex_llm.transformers.models.chatglm2 import chatglm2_model_forward
         from ipex_llm.transformers.models.chatglm2 import mlp_forward
@@ -1338,9 +1337,7 @@ def _optimize_post(model):
         convert_forward(model,
                         module.ChatGLMModel,
                         chatglm2_model_forward)
-        convert_forward(model,
-                        module.RMSNorm,
-                        chatglm_rms_norm_forward)
+        convert_forward(model, module.RMSNorm, rms_norm_forward)
         convert_forward(model, module.MLP, mlp_forward)
         # for codegeex-nano
         if hasattr(model.config, "rope_ratio"):
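
In this hunk the three-line convert_forward call that installed the ChatGLM-specific chatglm_rms_norm_forward collapses into a single call that reuses the shared rms_norm_forward (the generic helper patched further down in this commit). As a rough illustration of the patching pattern used throughout this diff, a convert_forward-style helper just rebinds `forward` on every matching submodule; this is a hypothetical sketch, not the actual ipex_llm implementation:

# Hypothetical sketch of the convert_forward patching pattern; the real
# helper lives in ipex_llm and may differ in details.
import types
import torch

def convert_forward_sketch(model: torch.nn.Module, target_cls, new_forward):
    for module in model.modules():
        if isinstance(module, target_cls):
            # Rebind `forward` so the optimized implementation runs instead.
            module.forward = types.MethodType(new_forward, module)
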
@@ -1358,8 +1355,7 @@ def _optimize_post(model):
         # glm4 family
         modeling_module_name = model.__class__.__module__
         module = importlib.import_module(modeling_module_name)
-        from ipex_llm.transformers.models.chatglm2 import chatglm_rms_norm_forward
-        convert_forward(model, module.RMSNorm, chatglm_rms_norm_forward)
+        convert_forward(model, module.RMSNorm, rms_norm_forward)

         if hasattr(model.transformer, "vision"):
             # glm4 vision family
@@ -1448,8 +1444,8 @@ def _optimize_post(model):
     elif model.config.model_type == "baichuan":
         modeling_module_name = model.__class__.__module__
         module = importlib.import_module(modeling_module_name)
-        from ipex_llm.transformers.models.baichuan import baichuan_mlp_forward
-        convert_forward(model, module.MLP, baichuan_mlp_forward)
+        convert_forward(model, module.RMSNorm, rms_norm_forward)
+        convert_forward(model, module.MLP, mlp_silu_forward)

         if model.config.hidden_size in [4096, 2048]:
             # baichuan-7B and baichuan2-7B
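
Baichuan's dedicated baichuan_mlp_forward (deleted further down in this diff) is replaced by the shared mlp_silu_forward. The eager fallback of the removed function is an ordinary SiLU-gated MLP; a minimal sketch, assuming `self` carries gate_proj, up_proj and down_proj linear layers, and noting that the real mlp_silu_forward may additionally dispatch to fused XPU kernels:

# Minimal sketch of a SiLU-gated MLP forward, mirroring the eager fallback
# of the removed baichuan_mlp_forward.
import torch
import torch.nn.functional as F

def silu_mlp_sketch(self, x: torch.Tensor) -> torch.Tensor:
    return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))
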
@@ -1458,7 +1454,6 @@ def _optimize_post(model):
             for i in range(len(model.model.layers)):
                 setattr(model.model.layers[i].self_attn, "layer_idx", i)
             convert_forward(model, module.Attention, baichuan_attention_forward_7b)
-            convert_forward(model, module.RMSNorm, rms_norm_forward)
             if model.config.vocab_size == 125696:
                 # baichuan2-7B
                 convert_forward(model, module.BaichuanModel, baichuan_model_7b_forward)
@@ -1468,9 +1463,7 @@ def _optimize_post(model):
         elif model.config.hidden_size == 5120:
             # baichuan-13B and baichuan2-13B
             from ipex_llm.transformers.models.baichuan import baichuan_attention_forward_13b
-            from ipex_llm.transformers.models.baichuan import baichuan_13b_rms_norm_forward
             convert_forward(model, module.BaichuanAttention, baichuan_attention_forward_13b)
-            convert_forward(model, module.RMSNorm, baichuan_13b_rms_norm_forward)

             if model.config.vocab_size == 125696:
                 # baichaun2-13B
@@ -1565,7 +1558,6 @@ def _optimize_post(model):
         from ipex_llm.transformers.models.qwen import qwen_attention_forward
         from ipex_llm.transformers.models.qwen import qwen_attention_forward_registered
         from ipex_llm.transformers.models.qwen import qwen_mlp_forward
-        from ipex_llm.transformers.models.chatglm2 import chatglm_rms_norm_forward
         from ipex_llm.transformers.models.qwen import qwen_model_forward
         if model.config.max_position_embeddings == 8192 \
                 and model.config.hidden_size == 4096:
@@ -1580,7 +1572,7 @@ def _optimize_post(model):
         )
         convert_forward(model,
                         module.RMSNorm,
-                        chatglm_rms_norm_forward)
+                        rms_norm_forward)
         convert_forward(model,
                         module.QWenMLP,
                         qwen_mlp_forward)
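
Across the ChatGLM, GLM-4, Baichuan and Qwen branches the pattern is the same: per-model RMSNorm forwards are dropped in favour of the shared rms_norm_forward. The per-model copies removed below (chatglm_rms_norm_forward, baichuan_13b_rms_norm_forward) all reduce to the same reference computation, reproduced here with `weight` and `eps` taken from the norm module:

# Reference RMSNorm computation, matching the CPU fallback of the per-model
# forwards that this refactor deletes.
import torch

def rms_norm_reference(hidden_states: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
    input_dtype = hidden_states.dtype
    hidden_states = hidden_states.to(torch.float32)
    variance = hidden_states.pow(2).mean(-1, keepdim=True)
    hidden_states = hidden_states * torch.rsqrt(variance + eps)
    return weight * hidden_states.to(input_dtype)
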
@@ -47,38 +47,6 @@ def pre_compute_inv_freq(module: torch.nn.Module):
         module.register_buffer("inv_freq", inv_freq, persistent=False)


-def baichuan_13b_rms_norm_forward(self, hidden_states):
-    if hidden_states.device.type == "xpu" and not (self.training or hidden_states.requires_grad):
-        import xe_addons
-        x_2d = hidden_states.reshape(-1, hidden_states.size(-1)).contiguous()
-        output = xe_addons.rms_norm(self.weight, x_2d, self.epsilon)
-        return output.reshape(hidden_states.shape)
-
-    input_dtype = hidden_states.dtype
-    hidden_states = hidden_states.to(torch.float32)
-    variance = hidden_states.pow(2).mean(-1, keepdim=True)
-    hidden_states = hidden_states * torch.rsqrt(variance + self.epsilon)
-    return self.weight * hidden_states.to(input_dtype)
-
-
-def baichuan_mlp_forward(
-    self,
-    x: torch.Tensor,
-) -> torch.Tensor:
-    x_2d = x.view(-1, x.shape[-1])
-    qtype = getattr(self.gate_proj, "qtype", None)
-    if mlp_fusion_check(x_2d, qtype, self.training):
-        import xe_linear
-        if not x_2d.is_contiguous():
-            x_2d = x_2d.contiguous()
-        return self.down_proj(xe_linear.mlp_forward_xpu(
-            x_2d, self.gate_proj.weight.data, self.up_proj.weight.data,
-            x_2d.shape[0], x_2d.shape[1], self.gate_proj.out_len,
-            SILU, qtype
-        ))
-    return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
-
-
 def baichuan_model_7b_forward(
     self,
     input_ids: torch.LongTensor = None,
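
With baichuan_13b_rms_norm_forward and baichuan_mlp_forward removed from the ipex_llm.transformers.models.baichuan module, the Baichuan path relies entirely on the shared conversions registered in the convert hunks above. The removed RMSNorm helper's XPU fast path follows a simple shape-handling pattern worth noting: flatten the leading dimensions to 2-D for the fused kernel, then restore the original shape. A sketch of just that pattern, where `fused_kernel_2d` is a hypothetical stand-in for the xe_addons call:

# Shape-handling pattern from the removed baichuan_13b_rms_norm_forward:
# flatten to 2-D, run the fused kernel, reshape back.
import torch

def apply_fused_2d(hidden_states: torch.Tensor, fused_kernel_2d):
    x_2d = hidden_states.reshape(-1, hidden_states.size(-1)).contiguous()
    out_2d = fused_kernel_2d(x_2d)
    return out_2d.reshape(hidden_states.shape)
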
@@ -36,24 +36,13 @@ import math
 import torch
 from typing import Optional, Tuple
 from transformers.models.bert.modeling_bert import BertSelfAttention, BertEncoder
+from ipex_llm.transformers.models.common import merge_linear
 from ipex_llm.utils.common import invalidInputError


 def merge_qkv(module: torch.nn.Module):
     if isinstance(module, BertSelfAttention):
-        q_w = module.query.weight.data
-        k_w = module.key.weight.data
-        v_w = module.value.weight.data
-        q_b = module.query.bias.data
-        k_b = module.key.bias.data
-        v_b = module.value.bias.data
-        new_w = torch.cat([q_w, k_w, v_w], dim=0)
-        new_b = torch.cat([q_b, k_b, v_b], dim=-1)
-        qkv = torch.nn.Linear(0, 0, bias=True)
-        qkv.weight = torch.nn.Parameter(new_w, requires_grad=False)
-        qkv.bias = torch.nn.Parameter(new_b, requires_grad=False)
-        qkv.in_features = module.query.in_features
-        qkv.out_features = module.query.out_features * 3
+        qkv = merge_linear([module.query, module.key, module.value])
         module.qkv = qkv
         del module.query
         del module.key
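
In the BERT file (judging from the transformers.models.bert imports), the hand-rolled QKV concatenation is replaced by a single merge_linear call from ipex_llm.transformers.models.common. The deleted lines show what such a merge produces; the following is a hedged sketch of an equivalent helper, not the actual ipex_llm implementation, and it assumes every input layer has a bias, as BertSelfAttention's projections do:

# Sketch of what merge_linear([query, key, value]) presumably builds: one
# Linear whose weight and bias are the row-wise concatenation of the inputs.
import torch

def merge_linear_sketch(linears):
    new_w = torch.cat([layer.weight.data for layer in linears], dim=0)
    new_b = torch.cat([layer.bias.data for layer in linears], dim=-1)
    merged = torch.nn.Linear(linears[0].in_features, new_w.size(0), bias=True)
    merged.weight = torch.nn.Parameter(new_w, requires_grad=False)
    merged.bias = torch.nn.Parameter(new_b, requires_grad=False)
    return merged
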
@@ -33,34 +33,6 @@ from ipex_llm.transformers.kv import DynamicCompressCache, DynamicCompressFp8Cac
 KV_CACHE_ALLOC_BLOCK_LENGTH = int(os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256))


-def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
-    """
-    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states
-    go from (batch, num_key_value_heads, seqlen, head_dim) to
-    (batch, num_attention_heads, seqlen, head_dim)
-    """
-    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
-    if n_rep == 1:
-        return hidden_states
-    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads,
-                                                           n_rep, slen, head_dim)
-    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
-
-
-def chatglm_rms_norm_forward(self, hidden_states):
-    if hidden_states.device.type == "xpu" and not (self.training and hidden_states.requires_grad):
-        import xe_addons
-        x_2d = hidden_states.reshape(-1, hidden_states.size(-1)).contiguous()
-        output = xe_addons.rms_norm(self.weight, x_2d, self.eps)
-        return output.reshape(hidden_states.shape)
-
-    input_dtype = hidden_states.dtype
-    hidden_states = hidden_states.to(torch.float32)
-    variance = hidden_states.pow(2).mean(-1, keepdim=True)
-    hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
-    return self.weight * hidden_states.to(input_dtype)
-
-
 def chatglm2_model_forward(
     self,
     input_ids,
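
Both helpers removed from the ipex_llm.transformers.models.chatglm2 module are covered elsewhere: chatglm_rms_norm_forward by the shared rms_norm_forward patched in the next hunk, and repeat_kv by whatever the upgraded code path uses instead. For repeat_kv, its own docstring states the expand-and-reshape is equivalent to torch.repeat_interleave along the KV-head axis, which a quick check confirms:

# Quick equivalence check for the removed repeat_kv helper: expanding a new
# axis and reshaping duplicates each KV head n_rep times, exactly like
# torch.repeat_interleave(kv, repeats=n_rep, dim=1).
import torch

kv = torch.randn(2, 4, 16, 64)   # (batch, num_kv_heads, seqlen, head_dim)
n_rep = 3
expanded = kv[:, :, None, :, :].expand(2, 4, n_rep, 16, 64)
repeated = expanded.reshape(2, 4 * n_rep, 16, 64)
assert torch.equal(repeated, torch.repeat_interleave(kv, repeats=n_rep, dim=1))
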
@@ -157,8 +157,10 @@ def rms_norm_forward(self, hidden_states: torch.Tensor):
     weight = self.weight
     if hasattr(self, "variance_epsilon"):
         eps = self.variance_epsilon
-    else:
+    elif hasattr(self, "epsilon"):
         eps = self.epsilon
+    else:
+        eps = self.eps

     if hidden_states.device.type == 'xpu' and hidden_states.dtype in [torch.float, torch.half]:
         import xe_addons
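
The attribute lookup added here is what allows the per-model RMSNorm forwards to be deleted: different norm implementations store the epsilon under different names (variance_epsilon in HF Llama-style norms, epsilon in Baichuan's RMSNorm, eps in ChatGLM/GLM-4's), and the shared rms_norm_forward now resolves all three. A minimal sketch of just the lookup, assuming `norm` is an RMSNorm-like module using one of those attribute names:

# Minimal sketch of the epsilon resolution added in this hunk.
def resolve_rms_eps(norm) -> float:
    if hasattr(norm, "variance_epsilon"):   # e.g. HF Llama-style RMSNorm
        return norm.variance_epsilon
    elif hasattr(norm, "epsilon"):          # e.g. Baichuan's RMSNorm (see above)
        return norm.epsilon
    else:                                   # e.g. ChatGLM / GLM-4's RMSNorm
        return norm.eps
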