Add basic glm4v support (#12345)
parent 69e3a56943
commit c8b7265359
1 changed file with 26 additions and 36 deletions
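The substance of the change is the dispatch condition in _optimize_post: the old model.config.num_layers == 40 and hasattr(model.config, 'rope_ratio') test is replaced by isinstance(model.config.eos_token_id, list), which the commit uses as the discriminator for the whole glm4 family, with num_layers == 40 checked afterwards to pick out the 9B text and vision variants. A quick way to see the new condition fire for a glm4-family checkpoint (the model id and network access are assumptions for illustration, not part of the diff):

from transformers import AutoConfig

# "THUDM/glm-4v-9b" is assumed here as a representative glm4-family checkpoint.
config = AutoConfig.from_pretrained("THUDM/glm-4v-9b", trust_remote_code=True)

# glm4-family configs carry a list of eos token ids, which is what the
# new `isinstance(model.config.eos_token_id, list)` branch keys on.
print(isinstance(config.eos_token_id, list))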
@@ -1422,52 +1422,42 @@ def _optimize_post(model, lightweight_bmm=False):
                             module.SelfAttention,
                             chatglm_attention_forward
                             )
-    elif model.config.num_layers == 40 and hasattr(model.config, 'rope_ratio'):
+    elif isinstance(model.config.eos_token_id, list):
+        # glm4 family
         modeling_module_name = model.__class__.__module__
         module = importlib.import_module(modeling_module_name)
+
+        from ipex_llm.transformers.models.chatglm2 import chatglm_rms_norm_forward
+        convert_forward(model, module.RMSNorm, chatglm_rms_norm_forward)
+
         if hasattr(model.transformer, "vision"):
-            # glm-4v-9b
+            # glm4 vision family
             modeling_module_name = model.transformer.vision.__class__.__module__
             vision_module = importlib.import_module(modeling_module_name)
+
             from ipex_llm.transformers.models.chatglm4v import chatglm4v_attention_forward
             from ipex_llm.transformers.models.chatglm4v import chatglm4v_model_forward
-            from ipex_llm.transformers.models.chatglm4v import visual_attention_forward
-            from ipex_llm.transformers.models.chatglm4v import patch_embedding_forward
-            from ipex_llm.transformers.models.chatglm2 import chatglm_rms_norm_forward
-            convert_forward(model,
-                            module.SelfAttention,
-                            chatglm4v_attention_forward)
-            convert_forward(model,
-                            module.ChatGLMModel,
-                            chatglm4v_model_forward)
-            convert_forward(model,
-                            module.RMSNorm,
-                            chatglm_rms_norm_forward)
-            convert_forward(model,
-                            vision_module.Attention,
-                            visual_attention_forward)
-            convert_forward(model,
-                            vision_module.PatchEmbedding,
-                            patch_embedding_forward)
-        else:
-            # glm-4-9b-chat
+            convert_forward(model, module.SelfAttention, chatglm4v_attention_forward)
+            convert_forward(model, module.ChatGLMModel, chatglm4v_model_forward)
+
+            if model.config.num_layers == 40:
+                # glm-4v-9b
+                from ipex_llm.transformers.models.chatglm4v import visual_attention_forward
+                from ipex_llm.transformers.models.chatglm4v import patch_embedding_forward
+                convert_forward(model, vision_module.Attention, visual_attention_forward)
+                convert_forward(model, vision_module.PatchEmbedding, patch_embedding_forward)
+            else:
+                # todo
+                pass
+
+        elif model.config.num_layers == 40:
+            # glm-4-9b
             from ipex_llm.transformers.models.chatglm4 import chatglm4_attention_forward
             from ipex_llm.transformers.models.chatglm4 import chatglm4_model_forward
-            from ipex_llm.transformers.models.chatglm2 import chatglm_rms_norm_forward
             from ipex_llm.transformers.models.chatglm4 import chatglm4_encoder_forward
-            convert_forward(model,
-                            module.SelfAttention,
-                            chatglm4_attention_forward)
-            convert_forward(model,
-                            module.ChatGLMModel,
-                            chatglm4_model_forward)
-            convert_forward(model,
-                            module.RMSNorm,
-                            chatglm_rms_norm_forward)
-            convert_forward(model,
-                            module.GLMTransformer,
-                            chatglm4_encoder_forward)
-
+            convert_forward(model, module.SelfAttention, chatglm4_attention_forward)
+            convert_forward(model, module.ChatGLMModel, chatglm4_model_forward)
+            convert_forward(model, module.GLMTransformer, chatglm4_encoder_forward)
     elif "mpt" in model.config.model_type:
         if model.config.architectures is not None:
            modeling_module_name = model.__class__.__module__
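All of the branches above funnel through convert_forward, which swaps a custom forward onto every instance of the target module class. A minimal sketch of that pattern, assuming a recursive walk over named_children() (an illustration of the technique, not ipex-llm's exact implementation):

import torch.nn as nn


def convert_forward(m: nn.Module, target_m: type, new_forward) -> None:
    # Sketch only: recursively rebind `new_forward` as the bound forward
    # method of every submodule that is an instance of `target_m`.
    for _, sub_m in m.named_children():
        if isinstance(sub_m, target_m):
            sub_m.forward = new_forward.__get__(sub_m, sub_m.__class__)
        convert_forward(sub_m, target_m, new_forward)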