rewrite minicpmv optimization (#11816)
This commit is contained in:
parent 447c8ed324
commit 4e178f0c5d

2 changed files with 140 additions and 1064 deletions
@@ -747,13 +747,16 @@ def _optimize_pre(model, qtype=None):
     if model.config.model_type == "llama":
         from ipex_llm.transformers.models.llama import merge_qkv
         model.apply(merge_qkv)
+    if model.config.model_type == "minicpm":
+        from ipex_llm.transformers.models.minicpm import merge_qkv
+        model.apply(merge_qkv)
     if model.config.model_type == "minicpmv":
         from ipex_llm.transformers.models.minicpmv import merge_qkv
         model.vpm.apply(merge_qkv)
-        if model.config.hidden_size == 3584 and model.config.vocab_size == 151666:
+        if model.config.hidden_size == 2304 and model.config.vocab_size == 122753:
+            model.llm.config.model_type = "minicpm"
+        elif model.config.hidden_size == 3584 and model.config.vocab_size == 151666:
             model.llm.config.model_type = "qwen2"
-            _optimize_pre(model.llm, qtype=qtype)
-            model.llm.config.model_type = "minicpmv"
         elif model.config.hidden_size == 4096 and model.config.vocab_size == 128256:
             model.llm.config.model_type = "llama"
             _optimize_pre(model.llm, qtype=qtype)
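Note: the `_optimize_pre` hunk dispatches on a (hidden_size, vocab_size) fingerprint because every MiniCPM-V checkpoint reports model_type "minicpmv" while wrapping a different base LLM. A minimal sketch of that mapping (the table and helper names below are illustrative, not ipex-llm API; the version labels come from the comments in the `_optimize_post` hunk further down):

    # (hidden_size, vocab_size) -> base LLM family wrapped by the checkpoint
    MINICPMV_LLM_FAMILY = {
        (2304, 122753): "minicpm",  # MiniCPM-V 2
        (3584, 151666): "qwen2",    # MiniCPM-V 2.6
        (4096, 128256): "llama",    # MiniCPM-V 2.5
    }

    def base_llm_family(config):
        # Returns None for an unrecognized checkpoint, in which case no
        # base-LLM optimization would be applied.
        return MINICPMV_LLM_FAMILY.get((config.hidden_size, config.vocab_size))

Tagging `model.llm.config.model_type` with the resolved family lets the recursive `_optimize_pre(model.llm, qtype=qtype)` call reuse the existing per-family branches unchanged.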
@@ -1699,31 +1702,16 @@ def _optimize_post(model, lightweight_bmm=False):
             module.StableLmModel,
             stablelm_model_forward
         )
-    elif model.config.model_type == 'minicpm':
+    elif model.config.model_type == "minicpm":
         modeling_module_name = model.__class__.__module__
         module = importlib.import_module(modeling_module_name)
-        if version.parse(trans_version) >= version.parse("4.39.0"):
-            from ipex_llm.transformers.models.minicpm import minicpm_attention_forward_4_39
-            convert_forward(model,
-                            module.MiniCPMAttention,
-                            minicpm_attention_forward_4_39)
-        else:
-            from ipex_llm.transformers.models.minicpm import minicpm_attention_forward
-            convert_forward(model,
-                            module.MiniCPMAttention,
-                            minicpm_attention_forward)
-        from ipex_llm.transformers.models.minicpm import minicpm_model_forward
-        convert_forward(model,
-                        module.MiniCPMMLP,
-                        llama_mlp_forward)
-        convert_forward(model,
-                        module.MiniCPMRMSNorm,
-                        llama_rms_norm_forward)
-
-        convert_forward(model,
-                        module.MiniCPMModel,
-                        minicpm_model_forward)
+        from ipex_llm.transformers.models.minicpm import minicpm_attention_forward
+        from ipex_llm.transformers.models.minicpm import minicpm_model_forward_wrapper
+        convert_forward(model, module.MiniCPMAttention, minicpm_attention_forward)
+        convert_forward(model, module.MiniCPMMLP, llama_mlp_forward)
+        convert_forward(model, module.MiniCPMRMSNorm, llama_rms_norm_forward)
+        minicpm_model_forward = minicpm_model_forward_wrapper(module.MiniCPMModel.forward)
+        convert_forward(model, module.MiniCPMModel, minicpm_model_forward)
     elif model.config.model_type == "minicpmv":
         modeling_module_name = model.__class__.__module__
         module = importlib.import_module(modeling_module_name)
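Note: the rewritten "minicpm" branch drops the transformers-version switch between `minicpm_attention_forward` and `minicpm_attention_forward_4_39` in favor of a single attention forward, and builds the model-level forward with `minicpm_model_forward_wrapper`, which closes over the stock `module.MiniCPMModel.forward` rather than reimplementing it. A minimal sketch of both patterns, assuming the model is a torch `nn.Module` tree (illustrative only; ipex-llm's actual `convert_forward` and wrapper live in its own modules):

    import types

    def convert_forward(model, target_cls, new_forward):
        # Rebind `forward` on every instance of target_cls in the module tree.
        for sub in model.modules():
            if sub.__class__ is target_cls:
                sub.forward = types.MethodType(new_forward, sub)

    def model_forward_wrapper(origin_forward):
        # Wrap the original forward so version-specific internals stay in the
        # upstream modeling code; only the hooks around it are replaced.
        def model_forward(self, *args, **kwargs):
            # pre/post processing (e.g. kv-cache preparation) would go here
            return origin_forward(self, *args, **kwargs)
        return model_forward

The single attention forward now presumably absorbs the version differences itself, while wrapping (rather than forking) `MiniCPMModel.forward` keeps whatever semantics the installed transformers release provides.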
@@ -1734,13 +1722,9 @@ def _optimize_post(model, lightweight_bmm=False):
         if model.config.hidden_size == 2304 and model.config.vocab_size == 122753:
             # MiniCPM-V 2
             model.llm.config.model_type = "minicpm"
-            _optimize_post(model.llm, lightweight_bmm=lightweight_bmm)
-            model.llm.config.model_type = "minicpmv"
-        if model.config.hidden_size == 3584 and model.config.vocab_size == 151666:
+        elif model.config.hidden_size == 3584 and model.config.vocab_size == 151666:
             # MiniCPM-V 2.6
             model.llm.config.model_type = "qwen2"
-            _optimize_post(model.llm, lightweight_bmm=lightweight_bmm)
-            model.llm.config.model_type = "minicpmv"
         elif model.config.hidden_size == 4096 and model.config.vocab_size == 128256:
             # MiniCPM-V 2.5
             model.llm.config.model_type = "llama"
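Note: both rewritten functions converge on the same masquerade pattern. The old code set `model.llm.config.model_type` to the base family, recursed, and restored the "minicpmv" tag separately in every branch; the rewrite removes the per-branch recursion and restore, presumably hoisting a single shared call after the if/elif chain, as the context lines of the `_optimize_pre` hunk suggest. A hypothetical context-manager rendering of that set/optimize/restore sequence (not ipex-llm API; the diff inlines these steps):

    from contextlib import contextmanager

    @contextmanager
    def masquerade_model_type(config, base_type):
        # Temporarily present the wrapped LLM as its base architecture so the
        # generic per-family optimizations apply, then restore the original tag.
        original = config.model_type
        config.model_type = base_type
        try:
            yield config
        finally:
            config.model_type = original

Usage would look like `with masquerade_model_type(model.llm.config, "qwen2"): _optimize_post(model.llm, lightweight_bmm=lightweight_bmm)`. Restoring the tag matters so that any later check of model_type sees the real "minicpmv" value again.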
File diff suppressed because it is too large (this second changed file accounts for most of the 1064 deletions).