optimize internvl2 4b performance (#11720)
This commit is contained in:
parent
f44b732aa8
commit
bbdff6edeb
1 changed file with 3 additions and 0 deletions
|
|
@@ -739,6 +739,8 @@ def _optimize_pre(model, qtype=None):
|
||||||
if model.config.model_type == "internlmxcomposer2":
|
if model.config.model_type == "internlmxcomposer2":
|
||||||
from ipex_llm.transformers.models.internlm import pre_process_attn_and_mlp
|
from ipex_llm.transformers.models.internlm import pre_process_attn_and_mlp
|
||||||
model.apply(pre_process_attn_and_mlp)
|
model.apply(pre_process_attn_and_mlp)
|
||||||
|
if model.config.model_type == "internvl_chat":
|
||||||
|
_optimize_pre(model.language_model)
|
||||||
if model.config.model_type == "gemma2":
|
if model.config.model_type == "gemma2":
|
||||||
from ipex_llm.transformers.models.gemma2 import merge_qkv
|
from ipex_llm.transformers.models.gemma2 import merge_qkv
|
||||||
model.apply(merge_qkv)
|
model.apply(merge_qkv)
|
||||||
|
|
@@ -1268,6 +1270,7 @@ def _optimize_post(model, lightweight_bmm=False):
|
||||||
from ipex_llm.transformers.models.internvl import _get_pos_embed
|
from ipex_llm.transformers.models.internvl import _get_pos_embed
|
||||||
vision_embedding = model.vision_model.embeddings
|
vision_embedding = model.vision_model.embeddings
|
||||||
vision_embedding._get_pos_embed = MethodType(_get_pos_embed, vision_embedding)
|
vision_embedding._get_pos_embed = MethodType(_get_pos_embed, vision_embedding)
|
||||||
|
_optimize_post(model.language_model, lightweight_bmm=lightweight_bmm)
|
||||||
elif model.config.model_type == "qwen":
|
elif model.config.model_type == "qwen":
|
||||||
if hasattr(model.config, "visual"):
|
if hasattr(model.config, "visual"):
|
||||||
# for Qwen-VL-Chat
|
# for Qwen-VL-Chat
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue