Add fused mlp optimizations to glm4 models (#12360)
* Add fused mlp to glm4 models
* Small fix
commit 1a6cbc473f
parent 520af4e9b5

2 changed files with 12 additions and 8 deletions
@@ -1055,13 +1055,15 @@ def _optimize_pre(model, qtype=None):
             # chatglm2 and chatglm3
             from ipex_llm.transformers.models.chatglm2 import split_mlp
             model.apply(split_mlp)
-        elif (
-            isinstance(model.config.eos_token_id, list)
-            and hasattr(model.transformer, "vision")
-            and model.config.num_layers != 40
-        ):
-            from ipex_llm.transformers.models.chatglm4v import merge_qkv
-            model.apply(merge_qkv)
+        elif isinstance(model.config.eos_token_id, list):
+            # glm4 family
+            if hasattr(model.transformer, "vision"):
+                if model.config.num_layers != 40:
+                    from ipex_llm.transformers.models.chatglm4v import merge_qkv
+                    model.apply(merge_qkv)
+            elif model.config.num_layers in [40, 28]:
+                from ipex_llm.transformers.models.chatglm2 import split_mlp
+                model.apply(split_mlp)

     return model

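For context on the pre-pass that glm4 (28/40-layer, non-vision) checkpoints now share with chatglm2/chatglm3: GLM MLPs keep the gate and up projections packed in a single dense_h_to_4h linear, and split_mlp unpacks them so the fused MLP forward can address them separately. The sketch below is only a minimal illustration of such a pass; the real helper lives in ipex_llm.transformers.models.chatglm2 and may differ in details.

import torch
import torch.nn as nn


def split_mlp(module: nn.Module):
    # Only touch MLP modules that still carry the packed projection.
    if hasattr(module, "dense_h_to_4h") and isinstance(module.dense_h_to_4h, nn.Linear):
        fused = module.dense_h_to_4h
        inter = fused.out_features // 2          # packed as [gate; up] along the output dim
        has_bias = fused.bias is not None
        kwargs = dict(bias=has_bias, dtype=fused.weight.dtype, device=fused.weight.device)
        gate_proj = nn.Linear(fused.in_features, inter, **kwargs)
        up_proj = nn.Linear(fused.in_features, inter, **kwargs)
        with torch.no_grad():
            gate_proj.weight.copy_(fused.weight[:inter])
            up_proj.weight.copy_(fused.weight[inter:])
            if has_bias:
                gate_proj.bias.copy_(fused.bias[:inter])
                up_proj.bias.copy_(fused.bias[inter:])
        module.gate_proj = gate_proj
        module.up_proj = up_proj
        del module.dense_h_to_4h                 # later code probes mlp.gate_proj instead

model.apply(split_mlp) then visits every submodule, which is exactly how the hunk above wires it up.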
@@ -1463,9 +1465,11 @@ def _optimize_post(model, lightweight_bmm=False):
                 from ipex_llm.transformers.models.chatglm4 import chatglm4_attention_forward
                 from ipex_llm.transformers.models.chatglm4 import chatglm4_model_forward
                 from ipex_llm.transformers.models.chatglm4 import chatglm4_encoder_forward
+                from ipex_llm.transformers.models.chatglm2 import mlp_forward
                 convert_forward(model, module.SelfAttention, chatglm4_attention_forward)
                 convert_forward(model, module.ChatGLMModel, chatglm4_model_forward)
                 convert_forward(model, module.GLMTransformer, chatglm4_encoder_forward)
+                convert_forward(model, module.MLP, mlp_forward)
     elif "mpt" in model.config.model_type:
         if model.config.architectures is not None:
             modeling_module_name = model.__class__.__module__
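The _optimize_post hunk then swaps the stock GLM4 MLP forward for chatglm2's fused mlp_forward through convert_forward, which (roughly) rebinds the forward of every instance of the given module class. Below is a hedged sketch of both pieces: a SwiGLU-style forward over the split weights and a minimal patching helper. ipex-llm's actual mlp_forward dispatches to fused and quantized kernels on XPU, so treat this purely as the reference shape of the computation, not the shipped implementation.

import torch
import torch.nn.functional as F


def mlp_forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
    # After split_mlp the module exposes gate_proj / up_proj instead of the
    # packed dense_h_to_4h; dense_4h_to_h projects back to the hidden size.
    gate = F.silu(self.gate_proj(hidden_states))
    up = self.up_proj(hidden_states)
    return self.dense_4h_to_h(gate * up)


def convert_forward(model, target_cls, new_forward):
    # Sketch of the patching helper used above: bind new_forward to every
    # instance of target_cls (the real helper lives in ipex_llm.transformers).
    for module in model.modules():
        if isinstance(module, target_cls):
            module.forward = new_forward.__get__(module, target_cls)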
@@ -56,7 +56,7 @@ def chatglm4_model_forward(
     if use_cache:
         inputs = input_ids if input_ids is not None else inputs_embeds
         use_compress_kv = should_use_compresskv(inputs, inputs.shape[1])
-        use_quantize_kv = use_quantize_kv_cache(self.encoder.layers[0].mlp.dense_h_to_4h,
+        use_quantize_kv = use_quantize_kv_cache(self.encoder.layers[0].mlp.gate_proj,
                                                 inputs)
         if use_compress_kv and not isinstance(past_key_values,
                                               DynamicCompressCache):
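The last hunk is the follow-up to split_mlp: once the packed dense_h_to_4h is gone from the GLM4 MLP, the quantized-KV-cache check in chatglm4_model_forward has to probe gate_proj instead, otherwise the attribute lookup would fail. use_quantize_kv_cache only needs one representative linear from the model to make its decision; the sketch below shows an assumed heuristic of that shape (the environment-variable override and the exact rule are illustrative, not necessarily ipex-llm's actual logic).

import os
import torch


def use_quantize_kv_cache(linear: torch.nn.Linear, inputs: torch.Tensor) -> bool:
    # Assumed shape of the heuristic: honor an explicit override, otherwise
    # quantize the KV cache for single-batch decoding when the probed linear
    # lives on an XPU device.
    env = os.environ.get("IPEX_LLM_QUANTIZE_KV_CACHE")
    if env is not None:
        return env == "1"
    return linear.weight.device.type == "xpu" and inputs.shape[0] == 1

Any linear from the first block works as a probe; gate_proj is simply the one guaranteed to exist after the split.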