diff --git a/python/llm/src/ipex_llm/transformers/convert.py b/python/llm/src/ipex_llm/transformers/convert.py
index dac4e086..cd7d9101 100644
--- a/python/llm/src/ipex_llm/transformers/convert.py
+++ b/python/llm/src/ipex_llm/transformers/convert.py
@@ -932,12 +932,13 @@ def _optimize_pre(model, qtype=None):
         logger.info("Only HuggingFace Transformers models are currently "
                     "supported for further optimizations")
         return model
 
+    # for rwkv models (verified RWKV/rwkv-4-world-7b)
     if model.config.model_type == "rwkv":
         model.rwkv._rescale_layers()
         model.rwkv.layers_are_rescaled = True
     # process NormHead module in Baichuan2 7B and 13B
-    if model.config.model_type == "baichuan" and model.config.vocab_size == 125696:
+    elif model.config.model_type == "baichuan" and model.config.vocab_size == 125696:
         # NormHead do normalization on the weights just once at inference time.
         # so we do it in advance and convert it to Linear so that it can be replaced.
         # modeling_module_name = model.__class__.__module__
@@ -958,30 +959,30 @@ def _optimize_pre(model, qtype=None):
         from ipex_llm.transformers.models.baichuan import pre_compute_inv_freq
         model.apply(pre_compute_inv_freq)
     # for yuan 2.0
-    if model.config.model_type == "yuan":
+    elif model.config.model_type == "yuan":
         from ipex_llm.transformers.models.yuan import merge_qk
         model.apply(merge_qk)
     # for bge-large
-    if model.config.model_type == 'bert' and (
+    elif model.config.model_type == 'bert' and (
         not model.config.is_decoder and
         model.config.position_embedding_type == "absolute"
     ):
         from ipex_llm.transformers.models.bert import merge_qkv
         model.apply(merge_qkv)
     # for starcoder2
-    if model.config.model_type == "starcoder2":
+    elif model.config.model_type == "starcoder2":
         from ipex_llm.transformers.models.starcoder2 import merge_qkv
         model.apply(merge_qkv)
-    if model.config.model_type == "phi":
+    elif model.config.model_type == "phi":
         from ipex_llm.transformers.models.phi import merge_qkv
         model.apply(merge_qkv)
-    if model.config.model_type in ["phi3", "phi3_v"]:
+    elif model.config.model_type in ["phi3", "phi3_v"]:
         from ipex_llm.transformers.models.phi3 import pre_compute_inv_freq
         model.apply(pre_compute_inv_freq)
         from ipex_llm.transformers.models.phi3 import split_mlp
         model.apply(split_mlp)
     # for qwen2
-    if model.config.model_type == "qwen2":
+    elif model.config.model_type == "qwen2":
         # Skip merge_qkv and padding_mlp if quant_method is 'gptq'
         should_apply_merge_qkv = (
             not hasattr(model.config, "quantization_config") or
@@ -994,51 +995,51 @@ def _optimize_pre(model, qtype=None):
         if qtype != ggml_tensor_qtype["fp6"]:
             from ipex_llm.transformers.models.qwen2 import padding_mlp
             model.apply(padding_mlp)
-    if model.config.model_type == "qwen2_moe":
+    elif model.config.model_type == "qwen2_moe":
         from ipex_llm.transformers.models.qwen2_moe import merge_qkv
         model.apply(merge_qkv)
-    if model.config.model_type == "qwen2_audio":
+    elif model.config.model_type == "qwen2_audio":
         from ipex_llm.transformers.models.qwen2 import merge_qkv
         model.language_model.apply(merge_qkv)
-    if model.config.model_type == "qwen2_vl":
+    elif model.config.model_type == "qwen2_vl":
         from ipex_llm.transformers.models.qwen2_vl import merge_qkv
         model.apply(merge_qkv)
-    if model.config.model_type == "stablelm":
+    elif model.config.model_type == "stablelm":
         # For stablelm-zephyr-3b and stablelm-2-zephyr-1_6b
         from ipex_llm.transformers.models.stablelm import merge_qkv
         model.apply(merge_qkv)
     # for internlm
-    if model.config.model_type == "internlm":
+    elif model.config.model_type == "internlm":
         from ipex_llm.transformers.models.internlm import merge_qkv
         model.apply(merge_qkv)
     # for internlm-xcomposer2-vl
-    if model.config.model_type == "internlmxcomposer2":
+    elif model.config.model_type == "internlmxcomposer2":
         from ipex_llm.transformers.models.internlm import pre_process_attn_and_mlp
         model.apply(pre_process_attn_and_mlp)
-    if model.config.model_type == "internvl_chat":
+    elif model.config.model_type == "internvl_chat":
         _optimize_pre(model.language_model, qtype=qtype)
-    if model.config.model_type == "gemma":
+    elif model.config.model_type == "gemma":
         from ipex_llm.transformers.models.gemma import merge_qkv, pre_compute_inv_freq
         model.apply(merge_qkv)
         model.apply(pre_compute_inv_freq)
-    if model.config.model_type == "gemma2":
+    elif model.config.model_type == "gemma2":
         from ipex_llm.transformers.models.gemma2 import merge_qkv
         model.apply(merge_qkv)
-    if model.config.model_type == "llama":
+    elif model.config.model_type == "llama":
         from ipex_llm.transformers.models.llama import merge_qkv
         model.apply(merge_qkv)
-    if model.config.model_type == "mllama":
+    elif model.config.model_type == "mllama":
         from ipex_llm.transformers.models.mllama import merge_qkv
         model.apply(merge_qkv)
-    if model.config.model_type == "minicpm":
+    elif model.config.model_type == "minicpm":
         from ipex_llm.transformers.models.minicpm import merge_qkv
         model.apply(merge_qkv)
-    if model.config.model_type == "minicpm3":
+    elif model.config.model_type == "minicpm3":
         from ipex_llm.transformers.models.minicpm3 import pre_compute_inv_freq
         model.apply(pre_compute_inv_freq)
         from ipex_llm.transformers.models.minicpm3 import padding_v_head_dim
         model.apply(padding_v_head_dim)
-    if model.config.model_type == "minicpmv":
+    elif model.config.model_type == "minicpmv":
         from ipex_llm.transformers.models.minicpmv import merge_qkv
         model.vpm.apply(merge_qkv)
         if model.config.hidden_size == 2304 and model.config.vocab_size == 122753:
@@ -1049,12 +1050,18 @@ def _optimize_pre(model, qtype=None):
             model.llm.config.model_type = "llama"
         _optimize_pre(model.llm, qtype=qtype)
         model.llm.config.model_type = "minicpmv"
-    if model.config.architectures is not None \
-            and model.config.architectures[0] in ["ChatGLMModel", "ChatGLMForConditionalGeneration"]:
-        from ipex_llm.transformers.models.chatglm2 import split_mlp
-        if hasattr(model.config, 'padded_vocab_size') and \
-                model.config.padded_vocab_size == 65024:
-            model.apply(split_mlp)
+    elif model.config.model_type == "chatglm":
+        if hasattr(model.config, 'padded_vocab_size') and model.config.padded_vocab_size == 65024:
+            # chatglm2 and chatglm3
+            from ipex_llm.transformers.models.chatglm2 import split_mlp
+            model.apply(split_mlp)
+        elif (
+            isinstance(model.config.eos_token_id, list)
+            and hasattr(model.transformer, "vision")
+            and model.config.num_layers != 40
+        ):
+            from ipex_llm.transformers.models.chatglm4v import merge_qkv
+            model.apply(merge_qkv)
 
     return model
 
@@ -1426,20 +1433,18 @@ def _optimize_post(model, lightweight_bmm=False):
             # glm4 family
             modeling_module_name = model.__class__.__module__
             module = importlib.import_module(modeling_module_name)
-
             from ipex_llm.transformers.models.chatglm2 import chatglm_rms_norm_forward
             convert_forward(model, module.RMSNorm, chatglm_rms_norm_forward)
 
             if hasattr(model.transformer, "vision"):
                 # glm4 vision family
-                modeling_module_name = model.transformer.vision.__class__.__module__
-                vision_module = importlib.import_module(modeling_module_name)
-
                 from ipex_llm.transformers.models.chatglm4v import chatglm4v_attention_forward
                 from ipex_llm.transformers.models.chatglm4v import chatglm4v_model_forward
                 convert_forward(model, module.SelfAttention, chatglm4v_attention_forward)
                 convert_forward(model, module.ChatGLMModel, chatglm4v_model_forward)
 
+                modeling_module_name = model.transformer.vision.__class__.__module__
+                vision_module = importlib.import_module(modeling_module_name)
                 if model.config.num_layers == 40:
                     # glm-4v-9b
                     from ipex_llm.transformers.models.chatglm4v import visual_attention_forward
@@ -1447,8 +1452,11 @@
                     convert_forward(model, vision_module.Attention, visual_attention_forward)
                     convert_forward(model, vision_module.PatchEmbedding, patch_embedding_forward)
                 else:
-                    # todo
-                    pass
+                    from transformers.models.siglip.modeling_siglip import SiglipAttention
+                    from ipex_llm.transformers.models.minicpmv import siglip_attention_forward
+                    convert_forward(model, SiglipAttention, siglip_attention_forward)
+                    from ipex_llm.transformers.models.chatglm4v import vision_model_forward
+                    convert_forward(model, vision_module.VisionModel, vision_model_forward)
 
             elif model.config.num_layers == 40:
                 # glm-4-9b
diff --git a/python/llm/src/ipex_llm/transformers/models/chatglm4v.py b/python/llm/src/ipex_llm/transformers/models/chatglm4v.py
index a315124b..2028cae0 100644
--- a/python/llm/src/ipex_llm/transformers/models/chatglm4v.py
+++ b/python/llm/src/ipex_llm/transformers/models/chatglm4v.py
@@ -19,6 +19,7 @@
 import torch
 from typing import Optional, Tuple, Union
 
+from ipex_llm.transformers.models.common import merge_qkv_base
 from ipex_llm.transformers.models.utils import restore_fp8_kv_cache, update_past_key_value
 from ipex_llm.transformers.models.utils import use_quantize_kv_cache, use_sdp, use_sdp_causal
 from ipex_llm.transformers.models.utils import should_use_fuse_rope, apply_rotary_pos_emb
@@ -339,3 +340,12 @@ def patch_embedding_forward(self, images: "tensor(B, C, H, W)") -> "tensor(B, L, D)":
     x = torch.cat((cls_token, x), dim=1)
     x += self.position_embedding.weight.unsqueeze(0).to(images.device)
     return x
+
+
+def merge_qkv(module: torch.nn.Module):
+    merge_qkv_base(module, "SiglipAttention")
+
+
+def vision_model_forward(self: torch.nn.Module, image: torch.Tensor):
+    vit_output = self.vit(image)
+    return self.adapter(vit_output)
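Note: the new `merge_qkv` / `vision_model_forward` hooks are only reached through `_optimize_pre()` / `_optimize_post()`, which run when a model is loaded via ipex-llm's Auto* wrappers. A minimal usage sketch is below; the checkpoint path is a placeholder, and it assumes a GLM-4V-family model whose config reports `model_type == "chatglm"`, a `transformer.vision` tower, and `num_layers != 40` (i.e. the new SiglipAttention path rather than glm-4v-9b).

```python
# Minimal sketch: loading a GLM-4V-family checkpoint through ipex-llm so that
# _optimize_pre()/_optimize_post() (and the chatglm4v merge_qkv /
# vision_model_forward hooks added above) are applied.
import torch
from transformers import AutoTokenizer
from ipex_llm.transformers import AutoModelForCausalLM

model_path = "path/to/glm-4v-family-checkpoint"  # placeholder, not a real model id

model = AutoModelForCausalLM.from_pretrained(
    model_path,
    load_in_4bit=True,       # low-bit conversion; runs the optimization passes
    optimize_model=True,     # enables the model-specific patches shown in this diff
    trust_remote_code=True,  # GLM checkpoints ship custom modeling code
)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

model = model.half().to("xpu")  # move to Intel GPU; keep on CPU if no XPU is available
```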