parent 94c4ce389f
commit 1b637e4477
2 changed files with 53 additions and 1 deletion
@@ -1049,6 +1049,12 @@ def _optimize_pre(model, qtype=None):
         model.llm.config.model_type = "llama"
         _optimize_pre(model.llm, qtype=qtype)
         model.llm.config.model_type = "minicpmv"
+    if model.config.architectures is not None \
+            and model.config.architectures[0] in ["ChatGLMModel", "ChatGLMForConditionalGeneration"]:
+        from ipex_llm.transformers.models.chatglm2 import split_mlp
+        if hasattr(model.config, 'padded_vocab_size') and \
+                model.config.padded_vocab_size == 65024:
+            model.apply(split_mlp)
 
     return model
 
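Note on the hunk above: `model.apply(split_mlp)` leans on `torch.nn.Module.apply`, which calls the given function on every submodule recursively and then on the root, so the class-name check inside `split_mlp` is what limits the rewrite to ChatGLM's MLP blocks. A minimal sketch of that traversal, using made-up TinyMLP/TinyBlock modules purely for illustration:

# Sketch only: Module.apply(fn) visits children first, then the module itself,
# so a filter inside fn (like split_mlp's class-name check) decides what gets rewritten.
import torch

class TinyMLP(torch.nn.Module):              # stand-in for ChatGLM's MLP
    def __init__(self):
        super().__init__()
        self.dense_h_to_4h = torch.nn.Linear(8, 32, bias=False)

class TinyBlock(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.mlp = TinyMLP()
        self.norm = torch.nn.LayerNorm(8)

def visit(module):
    print(type(module).__name__)             # Linear, TinyMLP, LayerNorm, TinyBlock

TinyBlock().apply(visit)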
@@ -1372,6 +1378,7 @@ def _optimize_post(model, lightweight_bmm=False):
             from ipex_llm.transformers.models.chatglm2 import chatglm_rms_norm_forward
             from ipex_llm.transformers.models.chatglm2 import chatglm2_encoder_forward
             from ipex_llm.transformers.models.chatglm2 import chatglm2_model_forward
+            from ipex_llm.transformers.models.chatglm2 import mlp_forward
             convert_forward(model,
                             module.SelfAttention,
                             chatglm2_attention_forward)
@@ -1384,6 +1391,7 @@ def _optimize_post(model, lightweight_bmm=False):
             convert_forward(model,
                             module.RMSNorm,
                             chatglm_rms_norm_forward)
+            convert_forward(model, module.MLP, mlp_forward)
         elif hasattr(model.config, 'padded_vocab_size') and \
                 model.config.padded_vocab_size == 64896:
             # codegeex-nano
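Note on the hunk above: `convert_forward(model, module.MLP, mlp_forward)` registers the new MLP path the same way the surrounding lines patch attention and RMSNorm. Roughly, a convert_forward-style helper rebinds the `forward` of every instance of the target class; the following is only a sketch of that pattern, not ipex-llm's actual implementation:

# Sketch of the convert_forward pattern: walk the model and rebind `forward`
# on every instance of the target module class so it runs the optimized path.
import types

def convert_forward_sketch(model, target_cls, new_forward):
    for sub_module in model.modules():
        if isinstance(sub_module, target_cls):
            sub_module.forward = types.MethodType(new_forward, sub_module)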
@@ -26,6 +26,7 @@ from ipex_llm.utils.common.log4Error import invalidInputError
 from ipex_llm.transformers.models.utils import restore_fp8_kv_cache, update_past_key_value
 from ipex_llm.transformers.models.utils import use_quantize_kv_cache, use_sdp, use_sdp_causal
 from ipex_llm.transformers.models.utils import should_use_fuse_rope, apply_rotary_pos_emb
+from ipex_llm.transformers.models.utils import mlp_fusion_check, SILU
 from ipex_llm.transformers.models.utils import use_quantize_kv_cache, use_sdp, \
     use_sdp_causal, should_use_compresskv, is_enough_kv_cache_room_4_36
 from ipex_llm.transformers.kv import DynamicCompressCache, DynamicCompressFp8Cache
@@ -91,7 +92,7 @@ def chatglm2_model_forward(
 
     if use_cache:
         use_compress_kv = should_use_compresskv(input_ids, input_ids.shape[1])
-        use_quantize_kv = use_quantize_kv_cache(self.encoder.layers[0].mlp.dense_h_to_4h,
+        use_quantize_kv = use_quantize_kv_cache(self.encoder.layers[0].mlp.gate_proj,
                                                 input_ids)
         if use_compress_kv and not isinstance(past_key_values,
                                               DynamicCompressCache):
@@ -570,3 +571,46 @@ def codegeex_attention_forward(
     output = self.dense(context_layer)
 
     return output, past_key_value
+
+import torch.nn.functional as F
+
+
+def split_mlp(module: torch.nn.Module):
+    if module.__class__.__name__ == "MLP":
+        gate_weight, up_weight = module.dense_h_to_4h.weight.data.chunk(2, dim=0)
+
+        gate_proj = torch.nn.Linear(0, 0, bias=False)
+        gate_proj.weight = torch.nn.Parameter(gate_weight, requires_grad=False)
+        gate_proj.in_features = gate_weight.size(1)
+        gate_proj.out_features = gate_weight.size(0)
+
+        up_proj = torch.nn.Linear(0, 0, bias=False)
+        up_proj.weight = torch.nn.Parameter(up_weight, requires_grad=False)
+        up_proj.in_features = up_weight.size(1)
+        up_proj.out_features = up_weight.size(0)
+
+        module.gate_proj = gate_proj
+        module.up_proj = up_proj
+
+        module.activation_fn = F.silu
+
+        del module.dense_h_to_4h
+
+
+def mlp_forward(
+    self,
+    hidden_states: torch.FloatTensor
+) -> torch.FloatTensor:
+    x_2d = hidden_states.view(-1, hidden_states.shape[-1])
+    qtype = getattr(self.gate_proj, "qtype", None)
+    if mlp_fusion_check(x_2d, qtype, self.training):
+        x_2d = x_2d.contiguous()
+        import xe_linear
+        return self.dense_4h_to_h(xe_linear.mlp_forward_xpu(
+            x_2d, self.gate_proj.weight.data, self.up_proj.weight.data,
+            x_2d.shape[0], x_2d.shape[1], self.gate_proj.out_features,
+            SILU, qtype
+        ))
+    return self.dense_4h_to_h(
+        self.activation_fn(self.gate_proj(hidden_states)) * self.up_proj(hidden_states)
+    )
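Why the split in `split_mlp` is safe: ChatGLM's fused `dense_h_to_4h` emits the gate and up projections concatenated along the output dimension, and the MLP applies SwiGLU as silu(gate) * up. Chunking the weight along dim 0 therefore yields two Linear layers whose outputs are exactly the two halves of the fused output, which is what the non-fused branch of `mlp_forward` computes. A small self-contained check with arbitrary sizes, independent of ipex-llm:

# Sanity sketch: chunking the fused gate/up weight along dim 0 reproduces the
# original SwiGLU computation. Sizes are arbitrary; this mirrors split_mlp and
# the non-fused branch of mlp_forward above.
import torch
import torch.nn.functional as F

hidden, inter = 8, 16
fused = torch.nn.Linear(hidden, 2 * inter, bias=False)
x = torch.randn(3, hidden)

# Original ChatGLM-style MLP: split the fused output, then silu(gate) * up
gate_out, up_out = fused(x).chunk(2, dim=-1)
ref = F.silu(gate_out) * up_out

# split_mlp-style rewrite: split the weight instead, into two Linear layers
gate_w, up_w = fused.weight.data.chunk(2, dim=0)
gate = torch.nn.Linear(hidden, inter, bias=False)
up = torch.nn.Linear(hidden, inter, bias=False)
gate.weight = torch.nn.Parameter(gate_w, requires_grad=False)
up.weight = torch.nn.Parameter(up_w, requires_grad=False)

out = F.silu(gate(x)) * up(x)
assert torch.allclose(ref, out, atol=1e-6)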