also convert SdpaAttention in optimize_model (#12673)
parent 2c23ce2553
commit c11f5f0fcd
6 changed files with 19 additions and 3 deletions
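
Background: with recent transformers releases the attention layers of these models are typically instantiated from the *SdpaAttention subclasses (when attn_implementation="sdpa", usually the default) rather than from the eager *Attention classes, so _optimize_post has to register its replacement forward on both variants. The snippet below is only a minimal sketch of the conversion pattern assumed here, not the actual ipex_llm implementation of convert_forward.

from types import MethodType

import torch


def convert_forward(model: torch.nn.Module, target_cls: type, new_forward) -> None:
    # Sketch only: rebind the optimized forward onto every module whose class
    # matches, so eager and SDPA attention share one optimized code path.
    for module in model.modules():
        if module.__class__ == target_cls:
            module.forward = MethodType(new_forward, module)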

@@ -1420,6 +1420,7 @@ def _optimize_post(model):
         convert_forward(model, module.GlmRMSNorm, rms_norm_forward)
         convert_forward(model, module.GlmMLP, mlp_silu_forward)
         convert_forward(model, module.GlmAttention, glm_attention_forward)
+        convert_forward(model, module.GlmSdpaAttention, glm_attention_forward)
         glm_model_forward = glm_model_forward_wrapper(module.GlmModel.forward)
         convert_forward(model, module.GlmModel, glm_model_forward)

@@ -1428,10 +1429,12 @@ def _optimize_post(model):
         vision_module_name = model.model.vision.__class__.__module__
         vision_module = importlib.import_module(vision_module_name)
         from transformers.models.siglip.modeling_siglip import SiglipAttention
+        from transformers.models.siglip.modeling_siglip import SiglipSdpaAttention
         from ipex_llm.transformers.models.chatglm4v import vision_model_forward
         from ipex_llm.transformers.models.minicpmv import siglip_attention_forward
         convert_forward(model, vision_module.VisionModel, vision_model_forward)
         convert_forward(model, SiglipAttention, siglip_attention_forward)
+        convert_forward(model, SiglipSdpaAttention, siglip_attention_forward)
     elif "mpt" in model.config.model_type:
         if model.config.architectures is not None:

@@ -1667,8 +1670,10 @@ def _optimize_post(model):
         convert_forward(model, module.Qwen2MLP, qwen2_mlp_forward)
         model.visual.get_dtype = MethodType(qwen2_vision_get_dtype, model.visual)
         convert_forward(model, module.VisionAttention, qwen2_vision_attention_forward)
+        convert_forward(model, module.VisionSdpaAttention, qwen2_vision_attention_forward)
         convert_forward(model, module.Qwen2VLModel, qwen2_vl_model_forward)
         convert_forward(model, module.Qwen2VLAttention, qwen2_vl_attention_forward)
+        convert_forward(model, module.Qwen2VLSdpaAttention, qwen2_vl_attention_forward)
     elif model.config.model_type == "aquila":
         modeling_module_name = model.__class__.__module__
         module = importlib.import_module(modeling_module_name)

@@ -1814,6 +1819,7 @@ def _optimize_post(model):
         from ipex_llm.transformers.models.starcoder2 import attention_forward
         from ipex_llm.transformers.models.starcoder2 import model_forward
         convert_forward(model, module.Starcoder2Attention, attention_forward)
+        convert_forward(model, module.Starcoder2SdpaAttention, attention_forward)
         convert_forward(model, module.Starcoder2Model, model_forward)
     elif model.config.model_type == "phi":
         # for phi-2

@@ -1829,6 +1835,7 @@ def _optimize_post(model):
         module = importlib.import_module(modeling_module_name)
         from ipex_llm.transformers.models.phi3 import attention_forward
         convert_forward(model, module.Phi3Attention, attention_forward)
+        convert_forward(model, module.Phi3SdpaAttention, attention_forward)
         from ipex_llm.transformers.models.phi3 import mlp_forward
         convert_forward(model, module.Phi3MLP, mlp_forward)
         from ipex_llm.transformers.models.common import rms_norm_forward

@@ -1872,6 +1879,8 @@ def _optimize_post(model):
                         module.StableLmAttention,
                         stablelm_attention_forward
                         )
+        if hasattr(module, "StableLmSdpaAttention"):
+            convert_forward(model, module.StableLmSdpaAttention, stablelm_attention_forward)
         convert_forward(model,
                         module.StableLmMLP,
                         mlp_silu_forward)

@@ -1886,6 +1895,7 @@ def _optimize_post(model):
         from ipex_llm.transformers.models.minicpm import minicpm_model_forward_wrapper
         from ipex_llm.transformers.models.minicpm import minicpm_decoder_layer_forward
         convert_forward(model, module.MiniCPMAttention, minicpm_attention_forward)
+        convert_forward(model, module.MiniCPMSdpaAttention, minicpm_attention_forward)
         convert_forward(model, module.MiniCPMMLP, mlp_silu_forward)
         convert_forward(model, module.MiniCPMRMSNorm, rms_norm_forward)
         convert_forward(model, module.MiniCPMDecoderLayer, minicpm_decoder_layer_forward)

@@ -1901,6 +1911,7 @@ def _optimize_post(model):
         convert_forward(model, module.MiniCPMRMSNorm, rms_norm_forward)
         convert_forward(model, module.MiniCPMMLP, mlp_silu_forward)
         convert_forward(model, module.MiniCPMAttention, minicpm3_attention_forward)
+        convert_forward(model, module.MiniCPMSdpaAttention, minicpm3_attention_forward)
         minicpm3_model_forward = minicpm3_model_forward_wrapper(module.MiniCPM3Model.forward)
         convert_forward(model, module.MiniCPM3Model, minicpm3_model_forward)
     elif model.config.model_type == "minicpmv":

@@ -301,6 +301,7 @@ def patch_embedding_forward(self, images: "tensor(B, C, H, W)") -> "tensor(B, L,
 
 def merge_qkv(module: torch.nn.Module):
     merge_qkv_base(module, "SiglipAttention")
+    merge_qkv_base(module, "SiglipSdpaAttention")
 
 
 def vision_model_forward(self: torch.nn.Module, image: torch.Tensor):

@@ -37,6 +37,7 @@ import torch
 
 from typing import Optional, Tuple
 from transformers.cache_utils import Cache
+from transformers.models.glm.modeling_glm import GlmAttention
 from transformers.models.glm.modeling_glm import apply_rotary_pos_emb
 from ipex_llm.transformers.kv import DynamicNormalCache, DynamicFp8Cache
 from ipex_llm.transformers.models.common import merge_qkv_base

@@ -46,8 +47,9 @@ from ipex_llm.transformers.models.utils import use_quantize_kv_cache
 
 
 def merge_qkv(module: torch.nn.Module):
-    merge_qkv_base(module, "GlmAttention")
+    merge_qkv_base(module, GlmAttention)
     merge_qkv_base(module, "SiglipAttention")
+    merge_qkv_base(module, "SiglipSdpaAttention")
 
 
 def split_mlp(module: torch.nn.Module):
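
The change from the string "GlmAttention" to the imported GlmAttention class suggests that merge_qkv_base accepts either a class object or a class name; passing the extra "SiglipSdpaAttention" name is then enough to cover the SDPA vision attention as well. Below is a rough sketch of that assumed behaviour, in which the separate q/k/v projections are fused into a single linear layer; the real ipex_llm helper may differ.

import torch


def merge_qkv_base(module: torch.nn.Module, attention_cls) -> None:
    # Assumed matching rule: accept a class object or a class-name string.
    if isinstance(attention_cls, type):
        matched = isinstance(module, attention_cls)
    else:
        matched = module.__class__.__name__ == attention_cls
    if not matched or not hasattr(module, "q_proj"):
        return
    # Fuse the three projections into one qkv_proj linear layer.
    q, k, v = module.q_proj, module.k_proj, module.v_proj
    qkv = torch.nn.Linear(
        q.in_features,
        q.out_features + k.out_features + v.out_features,
        bias=q.bias is not None,
    )
    qkv.weight.data = torch.cat([q.weight.data, k.weight.data, v.weight.data], dim=0)
    if q.bias is not None:
        qkv.bias.data = torch.cat([q.bias.data, k.bias.data, v.bias.data], dim=0)
    module.qkv_proj = qkv
    del module.q_proj, module.k_proj, module.v_proj

In practice such a helper is typically applied with model.apply(merge_qkv), so every submodule whose class matches picks up the fused projection.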

@@ -116,7 +116,7 @@ def llama_model_forward(
 
 
 def merge_qkv(module: torch.nn.Module):
-    return merge_qkv_base(module, LlamaAttention)
+    merge_qkv_base(module, LlamaAttention)
 
 
 def llama_attention_forward(

@@ -51,7 +51,8 @@ from transformers.cache_utils import Cache
 
 
 def merge_qkv(module: torch.nn.Module):
-    return merge_qkv_base(module, "MiniCPMAttention")
+    merge_qkv_base(module, "MiniCPMAttention")
+    merge_qkv_base(module, "MiniCPMSdpaAttention")
 
 
 def apply_residual_scale(module: torch.nn.Module):

@@ -36,6 +36,7 @@ from transformers.generation.logits_process import RepetitionPenaltyLogitsProcess
 # MiniCPM-V-2_5 and MiniCPM-V-2_6
 def merge_qkv(module: torch.nn.Module):
     merge_qkv_base(module, "SiglipAttention")
+    merge_qkv_base(module, "SiglipSdpaAttention")
     merge_qkv_base(module, "Idefics2VisionAttention")
 
 
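
For context, a hedged end-to-end usage sketch: loading a model through ipex_llm triggers optimize_model and _optimize_post, which after this change also patches the SdpaAttention variants listed above. The checkpoint name and keyword arguments below are illustrative assumptions following common ipex-llm examples, not anything taken from this diff.

from ipex_llm.transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "THUDM/glm-edge-1.5b-chat",      # hypothetical example checkpoint
    load_in_4bit=True,
    trust_remote_code=True,
    attn_implementation="sdpa",      # SDPA attention classes are now converted too
)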