optimize glm4v's vision part (#12346)
parent c8b7265359
commit e23ef7d088

2 changed files with 51 additions and 33 deletions

@@ -932,12 +932,13 @@ def _optimize_pre(model, qtype=None):
        logger.info("Only HuggingFace Transformers models are currently "
                    "supported for further optimizations")
        return model

    # for rwkv models (verified RWKV/rwkv-4-world-7b)
    if model.config.model_type == "rwkv":
        model.rwkv._rescale_layers()
        model.rwkv.layers_are_rescaled = True
    # process NormHead module in Baichuan2 7B and 13B
-    if model.config.model_type == "baichuan" and model.config.vocab_size == 125696:
+    elif model.config.model_type == "baichuan" and model.config.vocab_size == 125696:
        # NormHead do normalization on the weights just once at inference time.
        # so we do it in advance and convert it to Linear so that it can be replaced.
        # modeling_module_name = model.__class__.__module__

@@ -958,30 +959,30 @@ def _optimize_pre(model, qtype=None):
            from ipex_llm.transformers.models.baichuan import pre_compute_inv_freq
            model.apply(pre_compute_inv_freq)
    # for yuan 2.0
-    if model.config.model_type == "yuan":
+    elif model.config.model_type == "yuan":
        from ipex_llm.transformers.models.yuan import merge_qk
        model.apply(merge_qk)
    # for bge-large
-    if model.config.model_type == 'bert' and (
+    elif model.config.model_type == 'bert' and (
        not model.config.is_decoder and
        model.config.position_embedding_type == "absolute"
    ):
        from ipex_llm.transformers.models.bert import merge_qkv
        model.apply(merge_qkv)
    # for starcoder2
-    if model.config.model_type == "starcoder2":
+    elif model.config.model_type == "starcoder2":
        from ipex_llm.transformers.models.starcoder2 import merge_qkv
        model.apply(merge_qkv)
-    if model.config.model_type == "phi":
+    elif model.config.model_type == "phi":
        from ipex_llm.transformers.models.phi import merge_qkv
        model.apply(merge_qkv)
-    if model.config.model_type in ["phi3", "phi3_v"]:
+    elif model.config.model_type in ["phi3", "phi3_v"]:
        from ipex_llm.transformers.models.phi3 import pre_compute_inv_freq
        model.apply(pre_compute_inv_freq)
        from ipex_llm.transformers.models.phi3 import split_mlp
        model.apply(split_mlp)
    # for qwen2
-    if model.config.model_type == "qwen2":
+    elif model.config.model_type == "qwen2":
        # Skip merge_qkv and padding_mlp if quant_method is 'gptq'
        should_apply_merge_qkv = (
            not hasattr(model.config, "quantization_config") or

@@ -994,51 +995,51 @@ def _optimize_pre(model, qtype=None):
            if qtype != ggml_tensor_qtype["fp6"]:
                from ipex_llm.transformers.models.qwen2 import padding_mlp
                model.apply(padding_mlp)
-    if model.config.model_type == "qwen2_moe":
+    elif model.config.model_type == "qwen2_moe":
        from ipex_llm.transformers.models.qwen2_moe import merge_qkv
        model.apply(merge_qkv)
-    if model.config.model_type == "qwen2_audio":
+    elif model.config.model_type == "qwen2_audio":
        from ipex_llm.transformers.models.qwen2 import merge_qkv
        model.language_model.apply(merge_qkv)
-    if model.config.model_type == "qwen2_vl":
+    elif model.config.model_type == "qwen2_vl":
        from ipex_llm.transformers.models.qwen2_vl import merge_qkv
        model.apply(merge_qkv)
-    if model.config.model_type == "stablelm":
+    elif model.config.model_type == "stablelm":
        # For stablelm-zephyr-3b and stablelm-2-zephyr-1_6b
        from ipex_llm.transformers.models.stablelm import merge_qkv
        model.apply(merge_qkv)
    # for internlm
-    if model.config.model_type == "internlm":
+    elif model.config.model_type == "internlm":
        from ipex_llm.transformers.models.internlm import merge_qkv
        model.apply(merge_qkv)
    # for internlm-xcomposer2-vl
-    if model.config.model_type == "internlmxcomposer2":
+    elif model.config.model_type == "internlmxcomposer2":
        from ipex_llm.transformers.models.internlm import pre_process_attn_and_mlp
        model.apply(pre_process_attn_and_mlp)
-    if model.config.model_type == "internvl_chat":
+    elif model.config.model_type == "internvl_chat":
        _optimize_pre(model.language_model, qtype=qtype)
-    if model.config.model_type == "gemma":
+    elif model.config.model_type == "gemma":
        from ipex_llm.transformers.models.gemma import merge_qkv, pre_compute_inv_freq
        model.apply(merge_qkv)
        model.apply(pre_compute_inv_freq)
-    if model.config.model_type == "gemma2":
+    elif model.config.model_type == "gemma2":
        from ipex_llm.transformers.models.gemma2 import merge_qkv
        model.apply(merge_qkv)
-    if model.config.model_type == "llama":
+    elif model.config.model_type == "llama":
        from ipex_llm.transformers.models.llama import merge_qkv
        model.apply(merge_qkv)
-    if model.config.model_type == "mllama":
+    elif model.config.model_type == "mllama":
        from ipex_llm.transformers.models.mllama import merge_qkv
        model.apply(merge_qkv)
-    if model.config.model_type == "minicpm":
+    elif model.config.model_type == "minicpm":
        from ipex_llm.transformers.models.minicpm import merge_qkv
        model.apply(merge_qkv)
-    if model.config.model_type == "minicpm3":
+    elif model.config.model_type == "minicpm3":
        from ipex_llm.transformers.models.minicpm3 import pre_compute_inv_freq
        model.apply(pre_compute_inv_freq)
        from ipex_llm.transformers.models.minicpm3 import padding_v_head_dim
        model.apply(padding_v_head_dim)
-    if model.config.model_type == "minicpmv":
+    elif model.config.model_type == "minicpmv":
        from ipex_llm.transformers.models.minicpmv import merge_qkv
        model.vpm.apply(merge_qkv)
        if model.config.hidden_size == 2304 and model.config.vocab_size == 122753:
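
[Reviewer note] Each condition in this chain is keyed on the single `model.config.model_type` string, so rewriting the run of independent `if` checks as one `if`/`elif` chain does not change which branch runs; it only stops the remaining conditions from being evaluated once a branch has matched. A minimal sketch of the pattern, with illustrative branch bodies that are not taken from this PR:

    # Hedged sketch: mirrors the dispatch style of _optimize_pre, not its real contents.
    def dispatch_sketch(model_type: str) -> str:
        if model_type == "qwen2":
            return "merge_qkv + padding_mlp"
        elif model_type == "minicpmv":
            return "merge_qkv on the vision tower"
        elif model_type == "chatglm":
            return "split_mlp, or glm4v merge_qkv depending on the config"
        return "no pre-optimization"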

@@ -1049,12 +1050,18 @@ def _optimize_pre(model, qtype=None):
            model.llm.config.model_type = "llama"
        _optimize_pre(model.llm, qtype=qtype)
        model.llm.config.model_type = "minicpmv"
-    if model.config.architectures is not None \
-       and model.config.architectures[0] in ["ChatGLMModel", "ChatGLMForConditionalGeneration"]:
+    elif model.config.model_type == "chatglm":
-        if hasattr(model.config, 'padded_vocab_size') and model.config.padded_vocab_size == 65024:
+        if hasattr(model.config, 'padded_vocab_size') and \
+           model.config.padded_vocab_size == 65024:
            # chatglm2 and chatglm3
            from ipex_llm.transformers.models.chatglm2 import split_mlp
            model.apply(split_mlp)
+        elif (
+            isinstance(model.config.eos_token_id, list)
+            and hasattr(model.transformer, "vision")
+            and model.config.num_layers != 40
+        ):
+            from ipex_llm.transformers.models.chatglm4v import merge_qkv
+            model.apply(merge_qkv)

    return model
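
[Reviewer note] The new `elif model.config.model_type == "chatglm"` branch routes GLM4V-style checkpoints (a `transformer.vision` tower, a list-valued `eos_token_id`, and a layer count other than 40) to `chatglm4v.merge_qkv` during pre-optimization. For context, a hedged usage sketch of how these hooks are normally reached; the checkpoint id and keyword arguments are illustrative assumptions, not part of this change:

    from ipex_llm.transformers import AutoModelForCausalLM

    # Loading through ipex-llm's AutoModel wrapper is what triggers the
    # _optimize_pre/_optimize_post passes on the instantiated model.
    model = AutoModelForCausalLM.from_pretrained(
        "THUDM/glm-4v-9b",        # assumed checkpoint id, for illustration only
        load_in_4bit=True,        # low-bit weight conversion
        trust_remote_code=True,   # GLM checkpoints ship custom modeling code
    )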

@@ -1426,20 +1433,18 @@ def _optimize_post(model, lightweight_bmm=False):
            # glm4 family
            modeling_module_name = model.__class__.__module__
            module = importlib.import_module(modeling_module_name)

            from ipex_llm.transformers.models.chatglm2 import chatglm_rms_norm_forward
            convert_forward(model, module.RMSNorm, chatglm_rms_norm_forward)

            if hasattr(model.transformer, "vision"):
                # glm4 vision family
                modeling_module_name = model.transformer.vision.__class__.__module__
                vision_module = importlib.import_module(modeling_module_name)

                from ipex_llm.transformers.models.chatglm4v import chatglm4v_attention_forward
                from ipex_llm.transformers.models.chatglm4v import chatglm4v_model_forward
                convert_forward(model, module.SelfAttention, chatglm4v_attention_forward)
                convert_forward(model, module.ChatGLMModel, chatglm4v_model_forward)

                modeling_module_name = model.transformer.vision.__class__.__module__
                vision_module = importlib.import_module(modeling_module_name)
                if model.config.num_layers == 40:
                    # glm-4v-9b
                    from ipex_llm.transformers.models.chatglm4v import visual_attention_forward

@@ -1447,8 +1452,11 @@ def _optimize_post(model, lightweight_bmm=False):
                    convert_forward(model, vision_module.Attention, visual_attention_forward)
                    convert_forward(model, vision_module.PatchEmbedding, patch_embedding_forward)
                else:
-                    # todo
-                    pass
+                    from transformers.models.siglip.modeling_siglip import SiglipAttention
+                    from ipex_llm.transformers.models.minicpmv import siglip_attention_forward
+                    convert_forward(model, SiglipAttention, siglip_attention_forward)
+                    from ipex_llm.transformers.models.chatglm4v import vision_model_forward
+                    convert_forward(model, vision_module.VisionModel, vision_model_forward)

            elif model.config.num_layers == 40:
                # glm-4-9b
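
[Reviewer note] `convert_forward` is the helper used throughout `_optimize_post`: it walks the module tree and rebinds `forward` on every instance of the target class, which is how the new `else` branch swaps in `siglip_attention_forward` and `vision_model_forward` without patching the checkpoint's own modeling files. A conceptual sketch of that mechanism, not the actual ipex_llm implementation:

    import types

    import torch

    def convert_forward_sketch(model: torch.nn.Module, target_cls: type, new_forward) -> None:
        # Rebind `forward` on every submodule that is an instance of target_cls.
        for module in model.modules():
            if isinstance(module, target_cls):
                module.forward = types.MethodType(new_forward, module)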

@@ -19,6 +19,7 @@

import torch
from typing import Optional, Tuple, Union
+from ipex_llm.transformers.models.common import merge_qkv_base
from ipex_llm.transformers.models.utils import restore_fp8_kv_cache, update_past_key_value
from ipex_llm.transformers.models.utils import use_quantize_kv_cache, use_sdp, use_sdp_causal
from ipex_llm.transformers.models.utils import should_use_fuse_rope, apply_rotary_pos_emb

@@ -339,3 +340,12 @@ def patch_embedding_forward(self, images: "tensor(B, C, H, W)") -> "tensor(B, L,
    x = torch.cat((cls_token, x), dim=1)
    x += self.position_embedding.weight.unsqueeze(0).to(images.device)
    return x
+
+
+def merge_qkv(module: torch.nn.Module):
+    merge_qkv_base(module, "SiglipAttention")
+
+
+def vision_model_forward(self: torch.nn.Module, image: torch.Tensor):
+    vit_output = self.vit(image)
+    return self.adapter(vit_output)
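
[Reviewer note] The new `merge_qkv` hook delegates to `merge_qkv_base` from `ipex_llm.transformers.models.common`, asking it to fuse the q/k/v projections of every attention module whose class name is "SiglipAttention", so the vision tower issues one GEMM per layer instead of three. A conceptual sketch of that fusion, assuming the module exposes `q_proj`/`k_proj`/`v_proj` Linear layers and storing the fused layer as `qkv_proj`; this is illustrative, not the actual ipex_llm helper:

    import torch

    def merge_qkv_sketch(module: torch.nn.Module, attention_cls_name: str) -> None:
        # Fuse separate q/k/v projections into a single Linear (conceptual only).
        if module.__class__.__name__ != attention_cls_name:
            return
        q, k, v = module.q_proj, module.k_proj, module.v_proj
        qkv = torch.nn.Linear(q.in_features,
                              q.out_features + k.out_features + v.out_features,
                              bias=q.bias is not None)
        qkv.weight.data = torch.cat([q.weight.data, k.weight.data, v.weight.data], dim=0)
        if qkv.bias is not None:
            qkv.bias.data = torch.cat([q.bias.data, k.bias.data, v.bias.data], dim=0)
        module.qkv_proj = qkv
        del module.q_proj, module.k_proj, module.v_proj

In the PR itself, `model.apply(merge_qkv)` walks the model and lets the shared helper perform the equivalent fusion on each matching module.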