also convert SdpaAttention in optimize_model (#12673)
This commit is contained in:
		
							parent
							
								
									2c23ce2553
								
							
						
					
					
						commit
						c11f5f0fcd
					
				
					 6 changed files with 19 additions and 3 deletions
				
			
		| 
						 | 
				
			
			@ -1420,6 +1420,7 @@ def _optimize_post(model):
 | 
			
		|||
        convert_forward(model, module.GlmRMSNorm, rms_norm_forward)
 | 
			
		||||
        convert_forward(model, module.GlmMLP, mlp_silu_forward)
 | 
			
		||||
        convert_forward(model, module.GlmAttention, glm_attention_forward)
 | 
			
		||||
        convert_forward(model, module.GlmSdpaAttention, glm_attention_forward)
 | 
			
		||||
        glm_model_forward = glm_model_forward_wrapper(module.GlmModel.forward)
 | 
			
		||||
        convert_forward(model, module.GlmModel, glm_model_forward)
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -1428,10 +1429,12 @@ def _optimize_post(model):
 | 
			
		|||
            vision_module_name = model.model.vision.__class__.__module__
 | 
			
		||||
            vision_module = importlib.import_module(vision_module_name)
 | 
			
		||||
            from transformers.models.siglip.modeling_siglip import SiglipAttention
 | 
			
		||||
            from transformers.models.siglip.modeling_siglip import SiglipSdpaAttention
 | 
			
		||||
            from ipex_llm.transformers.models.chatglm4v import vision_model_forward
 | 
			
		||||
            from ipex_llm.transformers.models.minicpmv import siglip_attention_forward
 | 
			
		||||
            convert_forward(model, vision_module.VisionModel, vision_model_forward)
 | 
			
		||||
            convert_forward(model, SiglipAttention, siglip_attention_forward)
 | 
			
		||||
            convert_forward(model, SiglipSdpaAttention, siglip_attention_forward)
 | 
			
		||||
 | 
			
		||||
    elif "mpt" in model.config.model_type:
 | 
			
		||||
        if model.config.architectures is not None:
 | 
			
		||||
| 
						 | 
				
			
			@ -1667,8 +1670,10 @@ def _optimize_post(model):
 | 
			
		|||
        convert_forward(model, module.Qwen2MLP, qwen2_mlp_forward)
 | 
			
		||||
        model.visual.get_dtype = MethodType(qwen2_vision_get_dtype, model.visual)
 | 
			
		||||
        convert_forward(model, module.VisionAttention, qwen2_vision_attention_forward)
 | 
			
		||||
        convert_forward(model, module.VisionSdpaAttention, qwen2_vision_attention_forward)
 | 
			
		||||
        convert_forward(model, module.Qwen2VLModel, qwen2_vl_model_forward)
 | 
			
		||||
        convert_forward(model, module.Qwen2VLAttention, qwen2_vl_attention_forward)
 | 
			
		||||
        convert_forward(model, module.Qwen2VLSdpaAttention, qwen2_vl_attention_forward)
 | 
			
		||||
    elif model.config.model_type == "aquila":
 | 
			
		||||
        modeling_module_name = model.__class__.__module__
 | 
			
		||||
        module = importlib.import_module(modeling_module_name)
 | 
			
		||||
| 
						 | 
				
			
			@ -1814,6 +1819,7 @@ def _optimize_post(model):
 | 
			
		|||
        from ipex_llm.transformers.models.starcoder2 import attention_forward
 | 
			
		||||
        from ipex_llm.transformers.models.starcoder2 import model_forward
 | 
			
		||||
        convert_forward(model, module.Starcoder2Attention, attention_forward)
 | 
			
		||||
        convert_forward(model, module.Starcoder2SdpaAttention, attention_forward)
 | 
			
		||||
        convert_forward(model, module.Starcoder2Model, model_forward)
 | 
			
		||||
    elif model.config.model_type == "phi":
 | 
			
		||||
        # for phi-2
 | 
			
		||||
| 
						 | 
				
			
			@ -1829,6 +1835,7 @@ def _optimize_post(model):
 | 
			
		|||
        module = importlib.import_module(modeling_module_name)
 | 
			
		||||
        from ipex_llm.transformers.models.phi3 import attention_forward
 | 
			
		||||
        convert_forward(model, module.Phi3Attention, attention_forward)
 | 
			
		||||
        convert_forward(model, module.Phi3SdpaAttention, attention_forward)
 | 
			
		||||
        from ipex_llm.transformers.models.phi3 import mlp_forward
 | 
			
		||||
        convert_forward(model, module.Phi3MLP, mlp_forward)
 | 
			
		||||
        from ipex_llm.transformers.models.common import rms_norm_forward
 | 
			
		||||
| 
						 | 
				
			
			@ -1872,6 +1879,8 @@ def _optimize_post(model):
 | 
			
		|||
                        module.StableLmAttention,
 | 
			
		||||
                        stablelm_attention_forward
 | 
			
		||||
                        )
 | 
			
		||||
        if hasattr(module, "StableLmSdpaAttention"):
 | 
			
		||||
            convert_forward(model, module.StableLmSdpaAttention, stablelm_attention_forward)
 | 
			
		||||
        convert_forward(model,
 | 
			
		||||
                        module.StableLmMLP,
 | 
			
		||||
                        mlp_silu_forward)
 | 
			
		||||
| 
						 | 
				
			
			@ -1886,6 +1895,7 @@ def _optimize_post(model):
 | 
			
		|||
        from ipex_llm.transformers.models.minicpm import minicpm_model_forward_wrapper
 | 
			
		||||
        from ipex_llm.transformers.models.minicpm import minicpm_decoder_layer_forward
 | 
			
		||||
        convert_forward(model, module.MiniCPMAttention, minicpm_attention_forward)
 | 
			
		||||
        convert_forward(model, module.MiniCPMSdpaAttention, minicpm_attention_forward)
 | 
			
		||||
        convert_forward(model, module.MiniCPMMLP, mlp_silu_forward)
 | 
			
		||||
        convert_forward(model, module.MiniCPMRMSNorm, rms_norm_forward)
 | 
			
		||||
        convert_forward(model, module.MiniCPMDecoderLayer, minicpm_decoder_layer_forward)
 | 
			
		||||
| 
						 | 
				
			
			@ -1901,6 +1911,7 @@ def _optimize_post(model):
 | 
			
		|||
        convert_forward(model, module.MiniCPMRMSNorm, rms_norm_forward)
 | 
			
		||||
        convert_forward(model, module.MiniCPMMLP, mlp_silu_forward)
 | 
			
		||||
        convert_forward(model, module.MiniCPMAttention, minicpm3_attention_forward)
 | 
			
		||||
        convert_forward(model, module.MiniCPMSdpaAttention, minicpm3_attention_forward)
 | 
			
		||||
        minicpm3_model_forward = minicpm3_model_forward_wrapper(module.MiniCPM3Model.forward)
 | 
			
		||||
        convert_forward(model, module.MiniCPM3Model, minicpm3_model_forward)
 | 
			
		||||
    elif model.config.model_type == "minicpmv":
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -301,6 +301,7 @@ def patch_embedding_forward(self, images: "tensor(B, C, H, W)") -> "tensor(B, L,
 | 
			
		|||
 | 
			
		||||
def merge_qkv(module: torch.nn.Module):
 | 
			
		||||
    merge_qkv_base(module, "SiglipAttention")
 | 
			
		||||
    merge_qkv_base(module, "SiglipSdpaAttention")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def vision_model_forward(self: torch.nn.Module, image: torch.Tensor):
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -37,6 +37,7 @@ import torch
 | 
			
		|||
 | 
			
		||||
from typing import Optional, Tuple
 | 
			
		||||
from transformers.cache_utils import Cache
 | 
			
		||||
from transformers.models.glm.modeling_glm import GlmAttention
 | 
			
		||||
from transformers.models.glm.modeling_glm import apply_rotary_pos_emb
 | 
			
		||||
from ipex_llm.transformers.kv import DynamicNormalCache, DynamicFp8Cache
 | 
			
		||||
from ipex_llm.transformers.models.common import merge_qkv_base
 | 
			
		||||
| 
						 | 
				
			
			@ -46,8 +47,9 @@ from ipex_llm.transformers.models.utils import use_quantize_kv_cache
 | 
			
		|||
 | 
			
		||||
 | 
			
		||||
def merge_qkv(module: torch.nn.Module):
 | 
			
		||||
    merge_qkv_base(module, "GlmAttention")
 | 
			
		||||
    merge_qkv_base(module, GlmAttention)
 | 
			
		||||
    merge_qkv_base(module, "SiglipAttention")
 | 
			
		||||
    merge_qkv_base(module, "SiglipSdpaAttention")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def split_mlp(module: torch.nn.Module):
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -116,7 +116,7 @@ def llama_model_forward(
 | 
			
		|||
 | 
			
		||||
 | 
			
		||||
def merge_qkv(module: torch.nn.Module):
 | 
			
		||||
    return merge_qkv_base(module, LlamaAttention)
 | 
			
		||||
    merge_qkv_base(module, LlamaAttention)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def llama_attention_forward(
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -51,7 +51,8 @@ from transformers.cache_utils import Cache
 | 
			
		|||
 | 
			
		||||
 | 
			
		||||
def merge_qkv(module: torch.nn.Module):
 | 
			
		||||
    return merge_qkv_base(module, "MiniCPMAttention")
 | 
			
		||||
    merge_qkv_base(module, "MiniCPMAttention")
 | 
			
		||||
    merge_qkv_base(module, "MiniCPMSdpaAttention")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def apply_residual_scale(module: torch.nn.Module):
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -36,6 +36,7 @@ from transformers.generation.logits_process import RepetitionPenaltyLogitsProces
 | 
			
		|||
# MiniCPM-V-2_5 and MiniCPM-V-2_6
 | 
			
		||||
def merge_qkv(module: torch.nn.Module):
 | 
			
		||||
    merge_qkv_base(module, "SiglipAttention")
 | 
			
		||||
    merge_qkv_base(module, "SiglipSdpaAttention")
 | 
			
		||||
    merge_qkv_base(module, "Idefics2VisionAttention")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in a new issue