also convert SdpaAttention in optimize_model (#12673)

Author: Yishuo Wang, 2025-01-08 16:48:03 +08:00 (committed by GitHub)
parent 2c23ce2553
commit c11f5f0fcd (GPG key ID: B5690EEEBB952194; no known key found for this signature in database)
6 changed files with 19 additions and 3 deletions
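
Background for the change: recent transformers releases define a separate SDPA attention subclass (GlmSdpaAttention, Phi3SdpaAttention, and so on) for most architectures and instantiate it instead of the eager XxxAttention class when a model is loaded with attn_implementation="sdpa", the default since transformers 4.36. Because optimize_model swaps forwards per attention class, models built from the SDPA variants were previously left on the stock attention path; this commit registers the same optimized forward for the SDPA classes as well. The sketch below is a minimal, hypothetical illustration of that per-class conversion pattern, not the actual ipex_llm convert_forward implementation; it assumes exact-class matching, which is consistent with the need to register both classes:

    from types import MethodType
    import torch

    def convert_forward_sketch(model: torch.nn.Module, target_cls: type, new_forward) -> None:
        # Walk every submodule and rebind `forward` on exact instances of `target_cls`.
        # With exact-class matching, GlmAttention and GlmSdpaAttention each need
        # their own registration, which is what this commit adds.
        for module in model.modules():
            if module.__class__ is target_cls:
                module.forward = MethodType(new_forward, module)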


@@ -1420,6 +1420,7 @@ def _optimize_post(model):
 convert_forward(model, module.GlmRMSNorm, rms_norm_forward)
 convert_forward(model, module.GlmMLP, mlp_silu_forward)
 convert_forward(model, module.GlmAttention, glm_attention_forward)
+convert_forward(model, module.GlmSdpaAttention, glm_attention_forward)
 glm_model_forward = glm_model_forward_wrapper(module.GlmModel.forward)
 convert_forward(model, module.GlmModel, glm_model_forward)
@@ -1428,10 +1429,12 @@ def _optimize_post(model):
 vision_module_name = model.model.vision.__class__.__module__
 vision_module = importlib.import_module(vision_module_name)
 from transformers.models.siglip.modeling_siglip import SiglipAttention
+from transformers.models.siglip.modeling_siglip import SiglipSdpaAttention
 from ipex_llm.transformers.models.chatglm4v import vision_model_forward
 from ipex_llm.transformers.models.minicpmv import siglip_attention_forward
 convert_forward(model, vision_module.VisionModel, vision_model_forward)
 convert_forward(model, SiglipAttention, siglip_attention_forward)
+convert_forward(model, SiglipSdpaAttention, siglip_attention_forward)
 elif "mpt" in model.config.model_type:
 if model.config.architectures is not None:
@@ -1667,8 +1670,10 @@ def _optimize_post(model):
 convert_forward(model, module.Qwen2MLP, qwen2_mlp_forward)
 model.visual.get_dtype = MethodType(qwen2_vision_get_dtype, model.visual)
 convert_forward(model, module.VisionAttention, qwen2_vision_attention_forward)
+convert_forward(model, module.VisionSdpaAttention, qwen2_vision_attention_forward)
 convert_forward(model, module.Qwen2VLModel, qwen2_vl_model_forward)
 convert_forward(model, module.Qwen2VLAttention, qwen2_vl_attention_forward)
+convert_forward(model, module.Qwen2VLSdpaAttention, qwen2_vl_attention_forward)
 elif model.config.model_type == "aquila":
 modeling_module_name = model.__class__.__module__
 module = importlib.import_module(modeling_module_name)
@@ -1814,6 +1819,7 @@ def _optimize_post(model):
 from ipex_llm.transformers.models.starcoder2 import attention_forward
 from ipex_llm.transformers.models.starcoder2 import model_forward
 convert_forward(model, module.Starcoder2Attention, attention_forward)
+convert_forward(model, module.Starcoder2SdpaAttention, attention_forward)
 convert_forward(model, module.Starcoder2Model, model_forward)
 elif model.config.model_type == "phi":
 # for phi-2
@@ -1829,6 +1835,7 @@ def _optimize_post(model):
 module = importlib.import_module(modeling_module_name)
 from ipex_llm.transformers.models.phi3 import attention_forward
 convert_forward(model, module.Phi3Attention, attention_forward)
+convert_forward(model, module.Phi3SdpaAttention, attention_forward)
 from ipex_llm.transformers.models.phi3 import mlp_forward
 convert_forward(model, module.Phi3MLP, mlp_forward)
 from ipex_llm.transformers.models.common import rms_norm_forward
@@ -1872,6 +1879,8 @@ def _optimize_post(model):
 module.StableLmAttention,
 stablelm_attention_forward
 )
+if hasattr(module, "StableLmSdpaAttention"):
+convert_forward(model, module.StableLmSdpaAttention, stablelm_attention_forward)
 convert_forward(model,
 module.StableLmMLP,
 mlp_silu_forward)
@@ -1886,6 +1895,7 @@ def _optimize_post(model):
 from ipex_llm.transformers.models.minicpm import minicpm_model_forward_wrapper
 from ipex_llm.transformers.models.minicpm import minicpm_decoder_layer_forward
 convert_forward(model, module.MiniCPMAttention, minicpm_attention_forward)
+convert_forward(model, module.MiniCPMSdpaAttention, minicpm_attention_forward)
 convert_forward(model, module.MiniCPMMLP, mlp_silu_forward)
 convert_forward(model, module.MiniCPMRMSNorm, rms_norm_forward)
 convert_forward(model, module.MiniCPMDecoderLayer, minicpm_decoder_layer_forward)
@@ -1901,6 +1911,7 @@ def _optimize_post(model):
 convert_forward(model, module.MiniCPMRMSNorm, rms_norm_forward)
 convert_forward(model, module.MiniCPMMLP, mlp_silu_forward)
 convert_forward(model, module.MiniCPMAttention, minicpm3_attention_forward)
+convert_forward(model, module.MiniCPMSdpaAttention, minicpm3_attention_forward)
 minicpm3_model_forward = minicpm3_model_forward_wrapper(module.MiniCPM3Model.forward)
 convert_forward(model, module.MiniCPM3Model, minicpm3_model_forward)
 elif model.config.model_type == "minicpmv":


@@ -301,6 +301,7 @@ def patch_embedding_forward(self, images: "tensor(B, C, H, W)") -> "tensor(B, L,
 def merge_qkv(module: torch.nn.Module):
 merge_qkv_base(module, "SiglipAttention")
+merge_qkv_base(module, "SiglipSdpaAttention")
 def vision_model_forward(self: torch.nn.Module, image: torch.Tensor):


@@ -37,6 +37,7 @@ import torch
 from typing import Optional, Tuple
 from transformers.cache_utils import Cache
+from transformers.models.glm.modeling_glm import GlmAttention
 from transformers.models.glm.modeling_glm import apply_rotary_pos_emb
 from ipex_llm.transformers.kv import DynamicNormalCache, DynamicFp8Cache
 from ipex_llm.transformers.models.common import merge_qkv_base
@@ -46,8 +47,9 @@ from ipex_llm.transformers.models.utils import use_quantize_kv_cache
 def merge_qkv(module: torch.nn.Module):
-merge_qkv_base(module, "GlmAttention")
+merge_qkv_base(module, GlmAttention)
 merge_qkv_base(module, "SiglipAttention")
+merge_qkv_base(module, "SiglipSdpaAttention")
 def split_mlp(module: torch.nn.Module):
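
As the hunks above show, merge_qkv_base is called both with a class object (the newly imported GlmAttention) and with class-name strings ("SiglipAttention", "SiglipSdpaAttention"), so it evidently matches a module either by type or by class name. Conceptually it fuses the module's separate q/k/v projections into a single linear layer so one matmul produces all three tensors. The following is a rough, hypothetical sketch of that fusion under those assumptions, not the ipex_llm implementation (the real helper also has to cope with differing biases, dtypes, and low-bit/quantized layers):

    import torch
    from torch import nn

    def merge_qkv_sketch(module: nn.Module, target) -> None:
        # `target` may be a class or a class-name string, mirroring how
        # merge_qkv_base is called in the diff above.
        if isinstance(target, type):
            matched = module.__class__ is target
        else:
            matched = module.__class__.__name__ == target
        if not matched:
            return
        q, k, v = module.q_proj, module.k_proj, module.v_proj
        qkv = nn.Linear(q.in_features,
                        q.out_features + k.out_features + v.out_features,
                        bias=q.bias is not None)
        with torch.no_grad():
            qkv.weight.copy_(torch.cat([q.weight, k.weight, v.weight], dim=0))
            if q.bias is not None:
                qkv.bias.copy_(torch.cat([q.bias, k.bias, v.bias], dim=0))
        module.qkv_proj = qkv          # fused projection used by the optimized forward
        del module.q_proj, module.k_proj, module.v_proj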


@@ -116,7 +116,7 @@ def llama_model_forward(
 def merge_qkv(module: torch.nn.Module):
-return merge_qkv_base(module, LlamaAttention)
+merge_qkv_base(module, LlamaAttention)
 def llama_attention_forward(


@@ -51,7 +51,8 @@ from transformers.cache_utils import Cache
 def merge_qkv(module: torch.nn.Module):
-return merge_qkv_base(module, "MiniCPMAttention")
+merge_qkv_base(module, "MiniCPMAttention")
+merge_qkv_base(module, "MiniCPMSdpaAttention")
 def apply_residual_scale(module: torch.nn.Module):
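
The early return in merge_qkv is dropped here because a second merge_qkv_base call now follows it and would otherwise be unreachable (the llama change above drops its return simply for consistency). The return value itself does not matter if, as seems likely, these hooks are invoked through torch.nn.Module.apply, which ignores what the callback returns. A small self-contained illustration of that apply behaviour (illustrative only, not ipex_llm code):

    import torch
    from torch import nn

    def tag_linears(module: nn.Module) -> None:
        # nn.Module.apply ignores whatever this callback returns; only in-place
        # side effects matter, just like merge_qkv's layer surgery.
        if isinstance(module, nn.Linear):
            module.tagged = True

    model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
    model.apply(tag_linears)                      # visits every submodule, returns `model`
    print([hasattr(m, "tagged") for m in model])  # [True, False, True]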


@@ -36,6 +36,7 @@ from transformers.generation.logits_process import RepetitionPenaltyLogitsProces
 # MiniCPM-V-2_5 and MiniCPM-V-2_6
 def merge_qkv(module: torch.nn.Module):
 merge_qkv_base(module, "SiglipAttention")
+merge_qkv_base(module, "SiglipSdpaAttention")
 merge_qkv_base(module, "Idefics2VisionAttention")