diff --git a/python/llm/src/ipex_llm/transformers/convert.py b/python/llm/src/ipex_llm/transformers/convert.py
index 5c87c63c..5fc17437 100644
--- a/python/llm/src/ipex_llm/transformers/convert.py
+++ b/python/llm/src/ipex_llm/transformers/convert.py
@@ -1049,6 +1049,12 @@ def _optimize_pre(model, qtype=None):
         model.llm.config.model_type = "llama"
         _optimize_pre(model.llm, qtype=qtype)
         model.llm.config.model_type = "minicpmv"
+    if model.config.architectures is not None \
+            and model.config.architectures[0] in ["ChatGLMModel", "ChatGLMForConditionalGeneration"]:
+        from ipex_llm.transformers.models.chatglm2 import split_mlp
+        if hasattr(model.config, 'padded_vocab_size') and \
+                model.config.padded_vocab_size == 65024:
+            model.apply(split_mlp)

     return model

@@ -1372,6 +1378,7 @@ def _optimize_post(model, lightweight_bmm=False):
             from ipex_llm.transformers.models.chatglm2 import chatglm_rms_norm_forward
             from ipex_llm.transformers.models.chatglm2 import chatglm2_encoder_forward
             from ipex_llm.transformers.models.chatglm2 import chatglm2_model_forward
+            from ipex_llm.transformers.models.chatglm2 import mlp_forward
             convert_forward(model,
                             module.SelfAttention,
                             chatglm2_attention_forward)
@@ -1384,6 +1391,7 @@ def _optimize_post(model, lightweight_bmm=False):
             convert_forward(model,
                             module.RMSNorm,
                             chatglm_rms_norm_forward)
+            convert_forward(model, module.MLP, mlp_forward)
         elif hasattr(model.config, 'padded_vocab_size') and \
                 model.config.padded_vocab_size == 64896:
             # codegeex-nano
diff --git a/python/llm/src/ipex_llm/transformers/models/chatglm2.py b/python/llm/src/ipex_llm/transformers/models/chatglm2.py
index 43cfe816..9e213e17 100644
--- a/python/llm/src/ipex_llm/transformers/models/chatglm2.py
+++ b/python/llm/src/ipex_llm/transformers/models/chatglm2.py
@@ -26,6 +26,7 @@ from ipex_llm.utils.common.log4Error import invalidInputError
 from ipex_llm.transformers.models.utils import restore_fp8_kv_cache, update_past_key_value
 from ipex_llm.transformers.models.utils import use_quantize_kv_cache, use_sdp, use_sdp_causal
 from ipex_llm.transformers.models.utils import should_use_fuse_rope, apply_rotary_pos_emb
+from ipex_llm.transformers.models.utils import mlp_fusion_check, SILU
 from ipex_llm.transformers.models.utils import use_quantize_kv_cache, use_sdp, \
     use_sdp_causal, should_use_compresskv, is_enough_kv_cache_room_4_36
 from ipex_llm.transformers.kv import DynamicCompressCache, DynamicCompressFp8Cache
@@ -91,7 +92,7 @@ def chatglm2_model_forward(

     if use_cache:
         use_compress_kv = should_use_compresskv(input_ids, input_ids.shape[1])
-        use_quantize_kv = use_quantize_kv_cache(self.encoder.layers[0].mlp.dense_h_to_4h,
+        use_quantize_kv = use_quantize_kv_cache(self.encoder.layers[0].mlp.gate_proj,
                                                 input_ids)
         if use_compress_kv and not isinstance(past_key_values,
                                               DynamicCompressCache):
@@ -570,3 +571,46 @@ def codegeex_attention_forward(
     output = self.dense(context_layer)

     return output, past_key_value
+
+import torch.nn.functional as F
+
+
+def split_mlp(module: torch.nn.Module):
+    if module.__class__.__name__ == "MLP":
+        gate_weight, up_weight = module.dense_h_to_4h.weight.data.chunk(2, dim=0)
+
+        gate_proj = torch.nn.Linear(0, 0, bias=False)
+        gate_proj.weight = torch.nn.Parameter(gate_weight, requires_grad=False)
+        gate_proj.in_features = gate_weight.size(1)
+        gate_proj.out_features = gate_weight.size(0)
+
+        up_proj = torch.nn.Linear(0, 0, bias=False)
+        up_proj.weight = torch.nn.Parameter(up_weight, requires_grad=False)
+        up_proj.in_features = up_weight.size(1)
+        up_proj.out_features = up_weight.size(0)
+
+        module.gate_proj = gate_proj
+        module.up_proj = up_proj
+
+        module.activation_fn = F.silu
+
+        del module.dense_h_to_4h
+
+
+def mlp_forward(
+    self,
+    hidden_states: torch.FloatTensor
+) -> torch.FloatTensor:
+    x_2d = hidden_states.view(-1, hidden_states.shape[-1])
+    qtype = getattr(self.gate_proj, "qtype", None)
+    if mlp_fusion_check(x_2d, qtype, self.training):
+        x_2d = x_2d.contiguous()
+        import xe_linear
+        return self.dense_4h_to_h(xe_linear.mlp_forward_xpu(
+            x_2d, self.gate_proj.weight.data, self.up_proj.weight.data,
+            x_2d.shape[0], x_2d.shape[1], self.gate_proj.out_features,
+            SILU, qtype
+        ))
+    return self.dense_4h_to_h(
+        self.activation_fn(self.gate_proj(hidden_states)) * self.up_proj(hidden_states)
+    )
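
Note: as a sanity check on the transformation split_mlp performs, the standalone sketch below (not part of the patch; sizes are illustrative, not ChatGLM2's real 4096/13696 dimensions) verifies that chunking the fused dense_h_to_4h weight along dim 0 and computing silu(gate(x)) * up(x) reproduces ChatGLM2's original SwiGLU, which instead chunks the activation along the last dim:

# Equivalence check for the weight split used by split_mlp.
# Hypothetical sizes for illustration only.
import torch
import torch.nn.functional as F

torch.manual_seed(0)
hidden, intermediate = 32, 64
fused = torch.nn.Linear(hidden, 2 * intermediate, bias=False)  # plays dense_h_to_4h
x = torch.randn(2, hidden)

# Original ChatGLM2 path: one fused matmul, then chunk the activation.
gate_act, up_act = fused(x).chunk(2, dim=-1)
ref = F.silu(gate_act) * up_act

# split_mlp path: chunk the weight into gate/up projections instead.
gate_w, up_w = fused.weight.data.chunk(2, dim=0)
out = F.silu(x @ gate_w.t()) * (x @ up_w.t())

assert torch.allclose(ref, out, atol=1e-6)

The split is what allows mlp_forward to hand separate gate/up weights to the fused xe_linear.mlp_forward_xpu kernel when mlp_fusion_check passes, while the eager fallback path remains numerically identical to the original module.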