optimize glm4v's vision part (#12346)

Author: Yishuo Wang, 2024-11-06 15:43:40 +08:00 (committed by GitHub)
parent c8b7265359
commit e23ef7d088
2 changed files with 51 additions and 33 deletions

@@ -932,12 +932,13 @@ def _optimize_pre(model, qtype=None):
         logger.info("Only HuggingFace Transformers models are currently "
                     "supported for further optimizations")
         return model
     # for rwkv models (verified RWKV/rwkv-4-world-7b)
     if model.config.model_type == "rwkv":
         model.rwkv._rescale_layers()
         model.rwkv.layers_are_rescaled = True
     # process NormHead module in Baichuan2 7B and 13B
-    if model.config.model_type == "baichuan" and model.config.vocab_size == 125696:
+    elif model.config.model_type == "baichuan" and model.config.vocab_size == 125696:
         # NormHead do normalization on the weights just once at inference time.
         # so we do it in advance and convert it to Linear so that it can be replaced.
         # modeling_module_name = model.__class__.__module__
@@ -958,30 +959,30 @@ def _optimize_pre(model, qtype=None):
         from ipex_llm.transformers.models.baichuan import pre_compute_inv_freq
         model.apply(pre_compute_inv_freq)
     # for yuan 2.0
-    if model.config.model_type == "yuan":
+    elif model.config.model_type == "yuan":
         from ipex_llm.transformers.models.yuan import merge_qk
         model.apply(merge_qk)
     # for bge-large
-    if model.config.model_type == 'bert' and (
+    elif model.config.model_type == 'bert' and (
         not model.config.is_decoder and
         model.config.position_embedding_type == "absolute"
     ):
         from ipex_llm.transformers.models.bert import merge_qkv
         model.apply(merge_qkv)
     # for starcoder2
-    if model.config.model_type == "starcoder2":
+    elif model.config.model_type == "starcoder2":
         from ipex_llm.transformers.models.starcoder2 import merge_qkv
         model.apply(merge_qkv)
-    if model.config.model_type == "phi":
+    elif model.config.model_type == "phi":
         from ipex_llm.transformers.models.phi import merge_qkv
         model.apply(merge_qkv)
-    if model.config.model_type in ["phi3", "phi3_v"]:
+    elif model.config.model_type in ["phi3", "phi3_v"]:
         from ipex_llm.transformers.models.phi3 import pre_compute_inv_freq
         model.apply(pre_compute_inv_freq)
         from ipex_llm.transformers.models.phi3 import split_mlp
         model.apply(split_mlp)
     # for qwen2
-    if model.config.model_type == "qwen2":
+    elif model.config.model_type == "qwen2":
         # Skip merge_qkv and padding_mlp if quant_method is 'gptq'
         should_apply_merge_qkv = (
             not hasattr(model.config, "quantization_config") or
@@ -994,51 +995,51 @@ def _optimize_pre(model, qtype=None):
             if qtype != ggml_tensor_qtype["fp6"]:
                 from ipex_llm.transformers.models.qwen2 import padding_mlp
                 model.apply(padding_mlp)
-    if model.config.model_type == "qwen2_moe":
+    elif model.config.model_type == "qwen2_moe":
         from ipex_llm.transformers.models.qwen2_moe import merge_qkv
         model.apply(merge_qkv)
-    if model.config.model_type == "qwen2_audio":
+    elif model.config.model_type == "qwen2_audio":
         from ipex_llm.transformers.models.qwen2 import merge_qkv
         model.language_model.apply(merge_qkv)
-    if model.config.model_type == "qwen2_vl":
+    elif model.config.model_type == "qwen2_vl":
         from ipex_llm.transformers.models.qwen2_vl import merge_qkv
         model.apply(merge_qkv)
-    if model.config.model_type == "stablelm":
+    elif model.config.model_type == "stablelm":
         # For stablelm-zephyr-3b and stablelm-2-zephyr-1_6b
         from ipex_llm.transformers.models.stablelm import merge_qkv
         model.apply(merge_qkv)
     # for internlm
-    if model.config.model_type == "internlm":
+    elif model.config.model_type == "internlm":
         from ipex_llm.transformers.models.internlm import merge_qkv
         model.apply(merge_qkv)
     # for internlm-xcomposer2-vl
-    if model.config.model_type == "internlmxcomposer2":
+    elif model.config.model_type == "internlmxcomposer2":
         from ipex_llm.transformers.models.internlm import pre_process_attn_and_mlp
         model.apply(pre_process_attn_and_mlp)
-    if model.config.model_type == "internvl_chat":
+    elif model.config.model_type == "internvl_chat":
         _optimize_pre(model.language_model, qtype=qtype)
-    if model.config.model_type == "gemma":
+    elif model.config.model_type == "gemma":
         from ipex_llm.transformers.models.gemma import merge_qkv, pre_compute_inv_freq
         model.apply(merge_qkv)
         model.apply(pre_compute_inv_freq)
-    if model.config.model_type == "gemma2":
+    elif model.config.model_type == "gemma2":
         from ipex_llm.transformers.models.gemma2 import merge_qkv
         model.apply(merge_qkv)
-    if model.config.model_type == "llama":
+    elif model.config.model_type == "llama":
         from ipex_llm.transformers.models.llama import merge_qkv
         model.apply(merge_qkv)
-    if model.config.model_type == "mllama":
+    elif model.config.model_type == "mllama":
         from ipex_llm.transformers.models.mllama import merge_qkv
         model.apply(merge_qkv)
-    if model.config.model_type == "minicpm":
+    elif model.config.model_type == "minicpm":
         from ipex_llm.transformers.models.minicpm import merge_qkv
         model.apply(merge_qkv)
-    if model.config.model_type == "minicpm3":
+    elif model.config.model_type == "minicpm3":
         from ipex_llm.transformers.models.minicpm3 import pre_compute_inv_freq
         model.apply(pre_compute_inv_freq)
         from ipex_llm.transformers.models.minicpm3 import padding_v_head_dim
         model.apply(padding_v_head_dim)
-    if model.config.model_type == "minicpmv":
+    elif model.config.model_type == "minicpmv":
         from ipex_llm.transformers.models.minicpmv import merge_qkv
         model.vpm.apply(merge_qkv)
         if model.config.hidden_size == 2304 and model.config.vocab_size == 122753:
@@ -1049,12 +1050,18 @@ def _optimize_pre(model, qtype=None):
         model.llm.config.model_type = "llama"
         _optimize_pre(model.llm, qtype=qtype)
         model.llm.config.model_type = "minicpmv"
-    if model.config.architectures is not None \
-            and model.config.architectures[0] in ["ChatGLMModel", "ChatGLMForConditionalGeneration"]:
+    elif model.config.model_type == "chatglm":
         if hasattr(model.config, 'padded_vocab_size') and model.config.padded_vocab_size == 65024:
             # chatglm2 and chatglm3
             from ipex_llm.transformers.models.chatglm2 import split_mlp
             if hasattr(model.config, 'padded_vocab_size') and \
                     model.config.padded_vocab_size == 65024:
                 model.apply(split_mlp)
+        elif (
+            isinstance(model.config.eos_token_id, list)
+            and hasattr(model.transformer, "vision")
+            and model.config.num_layers != 40
+        ):
+            from ipex_llm.transformers.models.chatglm4v import merge_qkv
+            model.apply(merge_qkv)
     return model
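
Note that every per-model rewrite in `_optimize_pre` is driven by `model.apply(fn)`, which is plain `torch.nn.Module.apply`: it invokes `fn` on every submodule (children first, then the module itself), so a pass such as the new glm4v `merge_qkv` only needs to recognise the attention modules it wants to rewrite and can ignore everything else. A minimal sketch of the pattern, with an illustrative pass body that is not ipex-llm code:

import torch

def example_pass(module: torch.nn.Module):
    # A real pass (e.g. merge_qkv) matches on the module's class name and
    # rewrites matching modules in place; everything else is left untouched.
    if module.__class__.__name__ == "Linear":
        print("would rewrite", module)

model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU())
model.apply(example_pass)  # visits the Linear, the ReLU and the Sequential itself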
@@ -1426,20 +1433,18 @@ def _optimize_post(model, lightweight_bmm=False):
             # glm4 family
             modeling_module_name = model.__class__.__module__
             module = importlib.import_module(modeling_module_name)
             from ipex_llm.transformers.models.chatglm2 import chatglm_rms_norm_forward
             convert_forward(model, module.RMSNorm, chatglm_rms_norm_forward)
             if hasattr(model.transformer, "vision"):
                 # glm4 vision family
-                modeling_module_name = model.transformer.vision.__class__.__module__
-                vision_module = importlib.import_module(modeling_module_name)
                 from ipex_llm.transformers.models.chatglm4v import chatglm4v_attention_forward
                 from ipex_llm.transformers.models.chatglm4v import chatglm4v_model_forward
                 convert_forward(model, module.SelfAttention, chatglm4v_attention_forward)
                 convert_forward(model, module.ChatGLMModel, chatglm4v_model_forward)
+                modeling_module_name = model.transformer.vision.__class__.__module__
+                vision_module = importlib.import_module(modeling_module_name)
                 if model.config.num_layers == 40:
                     # glm-4v-9b
                     from ipex_llm.transformers.models.chatglm4v import visual_attention_forward
@@ -1447,8 +1452,11 @@ def _optimize_post(model, lightweight_bmm=False):
                     convert_forward(model, vision_module.Attention, visual_attention_forward)
                     convert_forward(model, vision_module.PatchEmbedding, patch_embedding_forward)
                 else:
-                    # todo
-                    pass
+                    from transformers.models.siglip.modeling_siglip import SiglipAttention
+                    from ipex_llm.transformers.models.minicpmv import siglip_attention_forward
+                    convert_forward(model, SiglipAttention, siglip_attention_forward)
+                    from ipex_llm.transformers.models.chatglm4v import vision_model_forward
+                    convert_forward(model, vision_module.VisionModel, vision_model_forward)
             elif model.config.num_layers == 40:
                 # glm-4-9b
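
The `convert_forward` calls are what actually install the optimized vision code path: rather than patching the model's source files, ipex-llm rebinds the `forward` of every instance of the target module class at load time. The helper below is a rough sketch of that idea (not ipex-llm's actual implementation), only to show how instances of `vision_module.VisionModel` end up calling the new `vision_model_forward`:

import types
import torch

def convert_forward_sketch(model: torch.nn.Module, target_cls, new_forward):
    # Walk all submodules and rebind `forward` on instances of target_cls,
    # so later calls are routed to the optimized implementation.
    for module in model.modules():
        if isinstance(module, target_cls):
            module.forward = types.MethodType(new_forward, module)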

@@ -19,6 +19,7 @@
 import torch
 from typing import Optional, Tuple, Union
+from ipex_llm.transformers.models.common import merge_qkv_base
 from ipex_llm.transformers.models.utils import restore_fp8_kv_cache, update_past_key_value
 from ipex_llm.transformers.models.utils import use_quantize_kv_cache, use_sdp, use_sdp_causal
 from ipex_llm.transformers.models.utils import should_use_fuse_rope, apply_rotary_pos_emb
@@ -339,3 +340,12 @@ def patch_embedding_forward(self, images: "tensor(B, C, H, W)") -> "tensor(B, L,
     x = torch.cat((cls_token, x), dim=1)
     x += self.position_embedding.weight.unsqueeze(0).to(images.device)
     return x
+
+
+def merge_qkv(module: torch.nn.Module):
+    merge_qkv_base(module, "SiglipAttention")
+
+
+def vision_model_forward(self: torch.nn.Module, image: torch.Tensor):
+    vit_output = self.vit(image)
+    return self.adapter(vit_output)
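
The new `merge_qkv` simply delegates to `merge_qkv_base(module, "SiglipAttention")`; the intent is to fuse the separate q/k/v projections of each SiglipAttention block in the vision tower into one wider linear layer, so the projection runs as a single matmul. The sketch below shows that kind of fusion in isolation; the attribute names (`q_proj`/`k_proj`/`v_proj`, `qkv_proj`) follow the usual SiglipAttention layout, and ipex-llm's real `merge_qkv_base` may differ in detail:

import torch

def fuse_qkv(attn: torch.nn.Module):
    # Concatenate the q/k/v projection weights (and biases) into a single
    # nn.Linear, then drop the originals; the replacement attention forward
    # is responsible for slicing the fused output back into q, k and v.
    q, k, v = attn.q_proj, attn.k_proj, attn.v_proj
    fused = torch.nn.Linear(q.in_features,
                            q.out_features + k.out_features + v.out_features,
                            bias=q.bias is not None)
    with torch.no_grad():
        fused.weight.copy_(torch.cat([q.weight, k.weight, v.weight], dim=0))
        if q.bias is not None:
            fused.bias.copy_(torch.cat([q.bias, k.bias, v.bias], dim=0))
    attn.qkv_proj = fused
    del attn.q_proj, attn.k_proj, attn.v_proj

Because the stock SiglipAttention.forward still expects separate projections, a fused layout only makes sense together with a replacement forward, which is why the `_optimize_post` hunk above installs `siglip_attention_forward` (from the minicpmv optimizations) for SiglipAttention alongside this merge.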