Add chatglm2&3 fuse mlp (#12328)

* add chatglm fuse mlp
Zhao Changmin 2024-11-04 18:04:41 +08:00 committed by GitHub
parent 94c4ce389f
commit 1b637e4477
2 changed files with 53 additions and 1 deletion

File: ipex_llm/transformers/convert.py

@@ -1049,6 +1049,12 @@ def _optimize_pre(model, qtype=None):
         model.llm.config.model_type = "llama"
         _optimize_pre(model.llm, qtype=qtype)
         model.llm.config.model_type = "minicpmv"
+    if model.config.architectures is not None \
+            and model.config.architectures[0] in ["ChatGLMModel", "ChatGLMForConditionalGeneration"]:
+        from ipex_llm.transformers.models.chatglm2 import split_mlp
+        if hasattr(model.config, 'padded_vocab_size') and \
+                model.config.padded_vocab_size == 65024:
+            model.apply(split_mlp)
 
     return model
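
For context: model.apply here is the standard torch.nn.Module.apply, which calls the given function on every submodule recursively, so a single call is enough for split_mlp to reach every ChatGLM MLP block. A minimal sketch of that traversal, using hypothetical toy classes (ToyMLP, ToyModel, mark_mlp are illustrative names, not part of this change):

    import torch

    # Hypothetical toy modules, only to show that nn.Module.apply visits
    # every submodule recursively; the real conversion targets ChatGLM's MLP.
    class ToyMLP(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.dense_h_to_4h = torch.nn.Linear(8, 32, bias=False)

    class ToyModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.layers = torch.nn.ModuleList([ToyMLP() for _ in range(2)])

    def mark_mlp(module):
        # Mirrors split_mlp's class-name check, but only tags the module.
        if module.__class__.__name__ == "ToyMLP":
            module.converted = True

    model = ToyModel()
    model.apply(mark_mlp)  # visits ToyModel, the ModuleList, and both ToyMLP instances
    print(all(getattr(m, "converted", False) for m in model.layers))  # True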
@@ -1372,6 +1378,7 @@ def _optimize_post(model, lightweight_bmm=False):
         from ipex_llm.transformers.models.chatglm2 import chatglm_rms_norm_forward
         from ipex_llm.transformers.models.chatglm2 import chatglm2_encoder_forward
         from ipex_llm.transformers.models.chatglm2 import chatglm2_model_forward
+        from ipex_llm.transformers.models.chatglm2 import mlp_forward
         convert_forward(model,
                         module.SelfAttention,
                         chatglm2_attention_forward)
@@ -1384,6 +1391,7 @@ def _optimize_post(model, lightweight_bmm=False):
             convert_forward(model,
                             module.RMSNorm,
                             chatglm_rms_norm_forward)
+            convert_forward(model, module.MLP, mlp_forward)
         elif hasattr(model.config, 'padded_vocab_size') and \
                 model.config.padded_vocab_size == 64896:
             # codegeex-nano
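
convert_forward itself is defined elsewhere in ipex_llm.transformers.convert and is not shown in this diff. Purely as a hedged illustration of the general technique it relies on, rebinding a plain function such as mlp_forward as the forward method of matching module instances, a sketch could look like the following; the patch_forward helper is hypothetical and not the library's actual implementation:

    import types
    import torch

    def patch_forward(model: torch.nn.Module, target_cls: type, new_forward):
        # Hypothetical helper: rebind new_forward as a bound method on every
        # submodule that is an instance of target_cls.
        for submodule in model.modules():
            if isinstance(submodule, target_cls):
                submodule.forward = types.MethodType(new_forward, submodule)

    # e.g. patching every ChatGLM MLP would route its calls through the
    # fused mlp_forward path added in this commit.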

File: ipex_llm/transformers/models/chatglm2.py

@@ -26,6 +26,7 @@ from ipex_llm.utils.common.log4Error import invalidInputError
 from ipex_llm.transformers.models.utils import restore_fp8_kv_cache, update_past_key_value
 from ipex_llm.transformers.models.utils import use_quantize_kv_cache, use_sdp, use_sdp_causal
 from ipex_llm.transformers.models.utils import should_use_fuse_rope, apply_rotary_pos_emb
+from ipex_llm.transformers.models.utils import mlp_fusion_check, SILU
 from ipex_llm.transformers.models.utils import use_quantize_kv_cache, use_sdp, \
     use_sdp_causal, should_use_compresskv, is_enough_kv_cache_room_4_36
 from ipex_llm.transformers.kv import DynamicCompressCache, DynamicCompressFp8Cache
@@ -91,7 +92,7 @@ def chatglm2_model_forward(
     if use_cache:
         use_compress_kv = should_use_compresskv(input_ids, input_ids.shape[1])
-        use_quantize_kv = use_quantize_kv_cache(self.encoder.layers[0].mlp.dense_h_to_4h,
+        use_quantize_kv = use_quantize_kv_cache(self.encoder.layers[0].mlp.gate_proj,
                                                 input_ids)
         if use_compress_kv and not isinstance(past_key_values,
                                               DynamicCompressCache):
@@ -570,3 +571,46 @@ def codegeex_attention_forward(
     output = self.dense(context_layer)
 
     return output, past_key_value
+
+
+import torch.nn.functional as F
+
+
+def split_mlp(module: torch.nn.Module):
+    if module.__class__.__name__ == "MLP":
+        gate_weight, up_weight = module.dense_h_to_4h.weight.data.chunk(2, dim=0)
+
+        gate_proj = torch.nn.Linear(0, 0, bias=False)
+        gate_proj.weight = torch.nn.Parameter(gate_weight, requires_grad=False)
+        gate_proj.in_features = gate_weight.size(1)
+        gate_proj.out_features = gate_weight.size(0)
+
+        up_proj = torch.nn.Linear(0, 0, bias=False)
+        up_proj.weight = torch.nn.Parameter(up_weight, requires_grad=False)
+        up_proj.in_features = up_weight.size(1)
+        up_proj.out_features = up_weight.size(0)
+
+        module.gate_proj = gate_proj
+        module.up_proj = up_proj
+        module.activation_fn = F.silu
+
+        del module.dense_h_to_4h
+
+
+def mlp_forward(
+    self,
+    hidden_states: torch.FloatTensor
+) -> torch.FloatTensor:
+    x_2d = hidden_states.view(-1, hidden_states.shape[-1])
+    qtype = getattr(self.gate_proj, "qtype", None)
+    if mlp_fusion_check(x_2d, qtype, self.training):
+        x_2d = x_2d.contiguous()
+        import xe_linear
+        return self.dense_4h_to_h(xe_linear.mlp_forward_xpu(
+            x_2d, self.gate_proj.weight.data, self.up_proj.weight.data,
+            x_2d.shape[0], x_2d.shape[1], self.gate_proj.out_features,
+            SILU, qtype
+        ))
+    return self.dense_4h_to_h(
+        self.activation_fn(self.gate_proj(hidden_states)) * self.up_proj(hidden_states)
+    )
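
The split performed by split_mlp is numerically equivalent to the original ChatGLM2/3 MLP, which applies a SwiGLU-style activation to the chunked output of dense_h_to_4h: chunking that activation along the last dimension is the same as chunking the weight rows (dim=0) before the matmul, which is exactly what the gate_proj/up_proj pair does. A small self-contained check, with toy sizes and names (hidden, inner, x) used only for illustration:

    import torch
    import torch.nn.functional as F

    torch.manual_seed(0)
    hidden, inner = 8, 16                       # toy sizes, not ChatGLM's real dims
    dense_h_to_4h = torch.nn.Linear(hidden, 2 * inner, bias=False)
    dense_4h_to_h = torch.nn.Linear(inner, hidden, bias=False)
    x = torch.randn(4, hidden)

    # Fused path: one matmul, then SwiGLU on the two halves of the activation.
    gate_out, up_out = dense_h_to_4h(x).chunk(2, dim=-1)
    y_fused = dense_4h_to_h(F.silu(gate_out) * up_out)

    # Split path, mirroring split_mlp: chunk the weight rows instead.
    gate_w, up_w = dense_h_to_4h.weight.data.chunk(2, dim=0)
    y_split = dense_4h_to_h(F.silu(x @ gate_w.t()) * (x @ up_w.t()))

    print(torch.allclose(y_fused, y_split, atol=1e-6))  # True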