refactor llama convert to fix minicpm-v 2.5 optimization (#11783)

This commit is contained in:
Yishuo Wang 2024-08-14 09:29:57 +08:00 committed by GitHub
parent 7cd6ec9723
commit cb79dcda93
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -754,6 +754,10 @@ def _optimize_pre(model, qtype=None):
model.llm.config.model_type = "qwen2"
_optimize_pre(model.llm, qtype=qtype)
model.llm.config.model_type = "minicpmv"
elif model.config.hidden_size == 4096 and model.config.vocab_size == 128256:
model.llm.config.model_type = "llama"
_optimize_pre(model.llm, qtype=qtype)
model.llm.config.model_type = "minicpmv"
return model
@ -933,16 +937,6 @@ def _optimize_ipex(model, qtype=ggml_tensor_qtype["bf16"]):
def _optimize_post(model, lightweight_bmm=False):
from packaging import version
from ipex_llm.transformers.models.llama import llama_attention_forward_4_31
from ipex_llm.transformers.models.llama import llama_attention_selective_batching_forward_4_31
from ipex_llm.transformers.models.llama import llama_model_selective_batching_forward_4_31
from ipex_llm.transformers.models.llama import llama_rms_norm_forward
from ipex_llm.transformers.models.llama import llama_mlp_forward
from ipex_llm.transformers.models.llama import llama_decoder_forward
from ipex_llm.transformers.models.llama import llama_model_forward
from transformers.modeling_utils import PreTrainedModel
try:
from sentence_transformers.SentenceTransformer import SentenceTransformer
if isinstance(model, SentenceTransformer):
@ -961,110 +955,80 @@ def _optimize_post(model, lightweight_bmm=False):
except ModuleNotFoundError:
pass
from transformers.modeling_utils import PreTrainedModel
# All huggingface format models are inherited from `PreTrainedModel`
if not isinstance(model, PreTrainedModel):
logger.info("Only HuggingFace Transformers models are currently "
"supported for further optimizations")
return model
vllm_selective_batching = os.getenv("VLLM_ENABLE_SELECTIVE_BATCHING")
enable_vllm_se_batching = vllm_selective_batching is not None
enable_vllm_se_batching = enable_vllm_se_batching and vllm_selective_batching.lower() == "true"
from packaging import version
trans_version = transformers.__version__
if version.parse(trans_version) >= version.parse("4.31.0"):
convert_forward(
model,
transformers.models.llama.modeling_llama.LlamaRMSNorm,
llama_rms_norm_forward,)
convert_forward(model,
transformers.models.llama.modeling_llama.LlamaMLP,
llama_mlp_forward)
convert_forward(model,
transformers.models.llama.modeling_llama.LlamaDecoderLayer,
llama_decoder_forward)
if version.parse(trans_version) >= version.parse("4.36.0"):
# transformers version >= 4.36.0
from ipex_llm.transformers.models.llama import llama_attention_forward_4_38
if version.parse(trans_version) >= version.parse("4.38.0"):
if version.parse(trans_version) >= version.parse("4.41.0"):
from ipex_llm.transformers.models.llama import llama_model_forward_4_41
from ipex_llm.transformers.models.llama import llama_attention_forward_4_41
convert_forward(
model,
transformers.models.llama.modeling_llama.LlamaModel,
llama_model_forward_4_41)
convert_forward(
model,
transformers.models.llama.modeling_llama.LlamaAttention,
llama_attention_forward_4_41)
convert_forward(
model,
transformers.models.llama.modeling_llama.LlamaSdpaAttention,
llama_attention_forward_4_41)
else:
from ipex_llm.transformers.models.llama import llama_model_forward_4_38
convert_forward(
model,
transformers.models.llama.modeling_llama.LlamaModel,
llama_model_forward_4_38)
convert_forward(
model,
transformers.models.llama.modeling_llama.LlamaAttention,
llama_attention_forward_4_38)
convert_forward(
model,
transformers.models.llama.modeling_llama.LlamaSdpaAttention,
llama_attention_forward_4_38)
else:
from ipex_llm.transformers.models.llama import llama_model_forward_4_36
convert_forward(
model,
transformers.models.llama.modeling_llama.LlamaModel,
llama_model_forward_4_36)
convert_forward(
model,
transformers.models.llama.modeling_llama.LlamaAttention,
llama_attention_forward_4_38)
convert_forward(
model,
transformers.models.llama.modeling_llama.LlamaSdpaAttention,
llama_attention_forward_4_38)
else:
# transformers version between 4.31.0 - 4.35.2
convert_forward(
model,
transformers.models.llama.modeling_llama.LlamaAttention,
llama_attention_forward_4_31, )
if enable_vllm_se_batching:
convert_forward(
model,
transformers.models.llama.modeling_llama.LlamaModel,
llama_model_selective_batching_forward_4_31,
)
convert_forward(
model,
transformers.models.llama.modeling_llama.LlamaAttention,
llama_attention_selective_batching_forward_4_31,
)
else:
convert_forward(
model,
transformers.models.llama.modeling_llama.LlamaModel,
llama_model_forward)
else:
# todo implement 4.28.0 ~ 4.30.2
pass
# convert all nn.LayerNorm
from ipex_llm.transformers.models.bloom import bloom_layer_norm_forward
convert_forward(model,
nn.LayerNorm,
bloom_layer_norm_forward)
from ipex_llm.transformers.models.llama import llama_rms_norm_forward
from ipex_llm.transformers.models.llama import llama_mlp_forward
if model.config.architectures is not None \
and model.config.architectures[0] in ["ChatGLMModel", "ChatGLMForConditionalGeneration"]:
if model.config.model_type == "llama":
from transformers.models.llama.modeling_llama import LlamaRMSNorm
from transformers.models.llama.modeling_llama import LlamaMLP
from transformers.models.llama.modeling_llama import LlamaAttention
from transformers.models.llama.modeling_llama import LlamaDecoderLayer
from transformers.models.llama.modeling_llama import LlamaModel
if version.parse(trans_version) >= version.parse("4.36.0"):
from transformers.models.llama.modeling_llama import LlamaSdpaAttention
from ipex_llm.transformers.models.llama import llama_rms_norm_forward
from ipex_llm.transformers.models.llama import llama_mlp_forward
from ipex_llm.transformers.models.llama import llama_decoder_forward
convert_forward(model, LlamaRMSNorm, llama_rms_norm_forward)
convert_forward(model, LlamaMLP, llama_mlp_forward)
convert_forward(model, LlamaDecoderLayer, llama_decoder_forward)
if version.parse(trans_version) >= version.parse("4.41.0"):
from ipex_llm.transformers.models.llama import llama_model_forward_4_41
from ipex_llm.transformers.models.llama import llama_attention_forward_4_41
convert_forward(model, LlamaModel, llama_model_forward_4_41)
convert_forward(model, LlamaAttention, llama_attention_forward_4_41)
convert_forward(model, LlamaSdpaAttention, llama_attention_forward_4_41)
elif version.parse(trans_version) >= version.parse("4.38.0"):
from ipex_llm.transformers.models.llama import llama_model_forward_4_38
from ipex_llm.transformers.models.llama import llama_attention_forward_4_38
convert_forward(model, LlamaModel, llama_model_forward_4_38)
convert_forward(model, LlamaAttention, llama_attention_forward_4_38)
convert_forward(model, LlamaSdpaAttention, llama_attention_forward_4_38)
elif version.parse(trans_version) >= version.parse("4.36.0"):
from ipex_llm.transformers.models.llama import llama_model_forward_4_36
from ipex_llm.transformers.models.llama import llama_attention_forward_4_38
convert_forward(model, LlamaModel, llama_model_forward_4_36)
convert_forward(model, LlamaAttention, llama_attention_forward_4_38)
convert_forward(model, LlamaSdpaAttention, llama_attention_forward_4_38)
else:
vllm_se_batching = os.getenv("VLLM_ENABLE_SELECTIVE_BATCHING", "").lower() == "true"
if vllm_se_batching:
from ipex_llm.transformers.models.llama import (
llama_model_selective_batching_forward_4_31,
llama_attention_selective_batching_forward_4_31,
)
convert_forward(model, LlamaModel,
llama_model_selective_batching_forward_4_31)
convert_forward(model, LlamaAttention,
llama_attention_selective_batching_forward_4_31)
else:
from ipex_llm.transformers.models.llama import llama_model_forward
from ipex_llm.transformers.models.llama import llama_attention_forward_4_31
convert_forward(model, LlamaModel, llama_model_forward)
convert_forward(model, LlamaAttention, llama_attention_forward_4_31)
elif (
model.config.architectures is not None
and model.config.architectures[0] in ["ChatGLMModel", "ChatGLMForConditionalGeneration"]
):
if hasattr(model.config, 'padded_vocab_size') and \
model.config.padded_vocab_size in [65024, 64896]:
# chatglm2-6b, chatglm2-6b-32k, chatglm3-6b, chatglm3-6b-32k, chatglm3-6b-128k
@ -1370,6 +1334,7 @@ def _optimize_post(model, lightweight_bmm=False):
from ipex_llm.transformers.models.qwen2_moe import qwen2moe_model_forward
from ipex_llm.transformers.models.qwen2_moe import qwen2_moe_causal_lm_forward
from ipex_llm.transformers.models.qwen2 import qwen2_attention_forward
from ipex_llm.transformers.models.qwen2 import qwen2_mlp_forward
convert_forward(model,
module.Qwen2MoeModel,
qwen2moe_model_forward)
@ -1384,7 +1349,7 @@ def _optimize_post(model, lightweight_bmm=False):
qwen2moe_moeblock_forward)
convert_forward(model,
module.Qwen2MoeMLP,
llama_mlp_forward)
qwen2_mlp_forward)
convert_forward(model,
module.Qwen2MoeAttention,
qwen2_attention_forward)
@ -1768,7 +1733,9 @@ def _optimize_post(model, lightweight_bmm=False):
model.llm.config.model_type = "minicpmv"
elif model.config.hidden_size == 4096 and model.config.vocab_size == 128256:
# MiniCPM-V 2.5
pass
model.llm.config.model_type = "llama"
_optimize_post(model.llm, lightweight_bmm=lightweight_bmm)
model.llm.config.model_type = "minicpmv"
vpm_modeling_module_name = model.vpm.__class__.__module__
vpm_module = importlib.import_module(vpm_modeling_module_name)