optimize internvl2 4b performance (#11720)

This commit is contained in:
Yishuo Wang 2024-08-06 14:25:08 +08:00 committed by GitHub
parent f44b732aa8
commit bbdff6edeb
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -739,6 +739,8 @@ def _optimize_pre(model, qtype=None):
if model.config.model_type == "internlmxcomposer2": if model.config.model_type == "internlmxcomposer2":
from ipex_llm.transformers.models.internlm import pre_process_attn_and_mlp from ipex_llm.transformers.models.internlm import pre_process_attn_and_mlp
model.apply(pre_process_attn_and_mlp) model.apply(pre_process_attn_and_mlp)
if model.config.model_type == "internvl_chat":
_optimize_pre(model.language_model)
if model.config.model_type == "gemma2": if model.config.model_type == "gemma2":
from ipex_llm.transformers.models.gemma2 import merge_qkv from ipex_llm.transformers.models.gemma2 import merge_qkv
model.apply(merge_qkv) model.apply(merge_qkv)
@ -1268,6 +1270,7 @@ def _optimize_post(model, lightweight_bmm=False):
from ipex_llm.transformers.models.internvl import _get_pos_embed from ipex_llm.transformers.models.internvl import _get_pos_embed
vision_embedding = model.vision_model.embeddings vision_embedding = model.vision_model.embeddings
vision_embedding._get_pos_embed = MethodType(_get_pos_embed, vision_embedding) vision_embedding._get_pos_embed = MethodType(_get_pos_embed, vision_embedding)
_optimize_post(model.language_model, lightweight_bmm=lightweight_bmm)
elif model.config.model_type == "qwen": elif model.config.model_type == "qwen":
if hasattr(model.config, "visual"): if hasattr(model.config, "visual"):
# for Qwen-VL-Chat # for Qwen-VL-Chat