From c11f5f0fcd76ed71638eae290ca640d31b692202 Mon Sep 17 00:00:00 2001
From: Yishuo Wang
Date: Wed, 8 Jan 2025 16:48:03 +0800
Subject: [PATCH] also convert SdpaAttention in optimize_model (#12673)

---
 python/llm/src/ipex_llm/transformers/convert.py           | 11 +++++++++++
 .../llm/src/ipex_llm/transformers/models/chatglm4v.py     |  1 +
 python/llm/src/ipex_llm/transformers/models/glm.py        |  4 +++-
 python/llm/src/ipex_llm/transformers/models/llama.py      |  2 +-
 .../llm/src/ipex_llm/transformers/models/minicpm.py       |  3 ++-
 .../llm/src/ipex_llm/transformers/models/minicpmv.py      |  1 +
 6 files changed, 19 insertions(+), 3 deletions(-)

diff --git a/python/llm/src/ipex_llm/transformers/convert.py b/python/llm/src/ipex_llm/transformers/convert.py
index 6f78b9a8..655d666f 100644
--- a/python/llm/src/ipex_llm/transformers/convert.py
+++ b/python/llm/src/ipex_llm/transformers/convert.py
@@ -1420,6 +1420,7 @@ def _optimize_post(model):
         convert_forward(model, module.GlmRMSNorm, rms_norm_forward)
         convert_forward(model, module.GlmMLP, mlp_silu_forward)
         convert_forward(model, module.GlmAttention, glm_attention_forward)
+        convert_forward(model, module.GlmSdpaAttention, glm_attention_forward)
         glm_model_forward = glm_model_forward_wrapper(module.GlmModel.forward)
         convert_forward(model, module.GlmModel, glm_model_forward)
 
@@ -1428,10 +1429,12 @@
             vision_module_name = model.model.vision.__class__.__module__
             vision_module = importlib.import_module(vision_module_name)
             from transformers.models.siglip.modeling_siglip import SiglipAttention
+            from transformers.models.siglip.modeling_siglip import SiglipSdpaAttention
             from ipex_llm.transformers.models.chatglm4v import vision_model_forward
             from ipex_llm.transformers.models.minicpmv import siglip_attention_forward
             convert_forward(model, vision_module.VisionModel, vision_model_forward)
             convert_forward(model, SiglipAttention, siglip_attention_forward)
+            convert_forward(model, SiglipSdpaAttention, siglip_attention_forward)
 
     elif "mpt" in model.config.model_type:
         if model.config.architectures is not None:
@@ -1667,8 +1670,10 @@
         convert_forward(model, module.Qwen2MLP, qwen2_mlp_forward)
         model.visual.get_dtype = MethodType(qwen2_vision_get_dtype, model.visual)
         convert_forward(model, module.VisionAttention, qwen2_vision_attention_forward)
+        convert_forward(model, module.VisionSdpaAttention, qwen2_vision_attention_forward)
         convert_forward(model, module.Qwen2VLModel, qwen2_vl_model_forward)
         convert_forward(model, module.Qwen2VLAttention, qwen2_vl_attention_forward)
+        convert_forward(model, module.Qwen2VLSdpaAttention, qwen2_vl_attention_forward)
     elif model.config.model_type == "aquila":
         modeling_module_name = model.__class__.__module__
         module = importlib.import_module(modeling_module_name)
@@ -1814,6 +1819,7 @@
         from ipex_llm.transformers.models.starcoder2 import attention_forward
         from ipex_llm.transformers.models.starcoder2 import model_forward
         convert_forward(model, module.Starcoder2Attention, attention_forward)
+        convert_forward(model, module.Starcoder2SdpaAttention, attention_forward)
         convert_forward(model, module.Starcoder2Model, model_forward)
     elif model.config.model_type == "phi":
         # for phi-2
@@ -1829,6 +1835,7 @@
         module = importlib.import_module(modeling_module_name)
         from ipex_llm.transformers.models.phi3 import attention_forward
         convert_forward(model, module.Phi3Attention, attention_forward)
+        convert_forward(model, module.Phi3SdpaAttention, attention_forward)
         from ipex_llm.transformers.models.phi3 import mlp_forward
         convert_forward(model, module.Phi3MLP, mlp_forward)
         from ipex_llm.transformers.models.common import rms_norm_forward
@@ -1872,6 +1879,8 @@
                         module.StableLmAttention,
                         stablelm_attention_forward
                         )
+        if hasattr(module, "StableLmSdpaAttention"):
+            convert_forward(model, module.StableLmSdpaAttention, stablelm_attention_forward)
         convert_forward(model, module.StableLmMLP, mlp_silu_forward)
 
 
@@ -1886,6 +1895,7 @@
         from ipex_llm.transformers.models.minicpm import minicpm_model_forward_wrapper
         from ipex_llm.transformers.models.minicpm import minicpm_decoder_layer_forward
         convert_forward(model, module.MiniCPMAttention, minicpm_attention_forward)
+        convert_forward(model, module.MiniCPMSdpaAttention, minicpm_attention_forward)
         convert_forward(model, module.MiniCPMMLP, mlp_silu_forward)
         convert_forward(model, module.MiniCPMRMSNorm, rms_norm_forward)
         convert_forward(model, module.MiniCPMDecoderLayer, minicpm_decoder_layer_forward)
@@ -1901,6 +1911,7 @@
         convert_forward(model, module.MiniCPMRMSNorm, rms_norm_forward)
         convert_forward(model, module.MiniCPMMLP, mlp_silu_forward)
         convert_forward(model, module.MiniCPMAttention, minicpm3_attention_forward)
+        convert_forward(model, module.MiniCPMSdpaAttention, minicpm3_attention_forward)
         minicpm3_model_forward = minicpm3_model_forward_wrapper(module.MiniCPM3Model.forward)
         convert_forward(model, module.MiniCPM3Model, minicpm3_model_forward)
     elif model.config.model_type == "minicpmv":
diff --git a/python/llm/src/ipex_llm/transformers/models/chatglm4v.py b/python/llm/src/ipex_llm/transformers/models/chatglm4v.py
index 86968463..10028bca 100644
--- a/python/llm/src/ipex_llm/transformers/models/chatglm4v.py
+++ b/python/llm/src/ipex_llm/transformers/models/chatglm4v.py
@@ -301,6 +301,7 @@ def patch_embedding_forward(self, images: "tensor(B, C, H, W)") -> "tensor(B, L,
 
 def merge_qkv(module: torch.nn.Module):
     merge_qkv_base(module, "SiglipAttention")
+    merge_qkv_base(module, "SiglipSdpaAttention")
 
 
 def vision_model_forward(self: torch.nn.Module, image: torch.Tensor):
diff --git a/python/llm/src/ipex_llm/transformers/models/glm.py b/python/llm/src/ipex_llm/transformers/models/glm.py
index 4d29835c..f0a2d17a 100644
--- a/python/llm/src/ipex_llm/transformers/models/glm.py
+++ b/python/llm/src/ipex_llm/transformers/models/glm.py
@@ -37,6 +37,7 @@
 import torch
 from typing import Optional, Tuple
 from transformers.cache_utils import Cache
+from transformers.models.glm.modeling_glm import GlmAttention
 from transformers.models.glm.modeling_glm import apply_rotary_pos_emb
 from ipex_llm.transformers.kv import DynamicNormalCache, DynamicFp8Cache
 from ipex_llm.transformers.models.common import merge_qkv_base
@@ -46,8 +47,9 @@ from ipex_llm.transformers.models.utils import use_quantize_kv_cache
 
 
 def merge_qkv(module: torch.nn.Module):
-    merge_qkv_base(module, "GlmAttention")
+    merge_qkv_base(module, GlmAttention)
     merge_qkv_base(module, "SiglipAttention")
+    merge_qkv_base(module, "SiglipSdpaAttention")
 
 
 def split_mlp(module: torch.nn.Module):
diff --git a/python/llm/src/ipex_llm/transformers/models/llama.py b/python/llm/src/ipex_llm/transformers/models/llama.py
index 7160f46d..610f1ac0 100644
--- a/python/llm/src/ipex_llm/transformers/models/llama.py
+++ b/python/llm/src/ipex_llm/transformers/models/llama.py
@@ -116,7 +116,7 @@ def llama_model_forward(
 
 
 def merge_qkv(module: torch.nn.Module):
-    return merge_qkv_base(module, LlamaAttention)
+    merge_qkv_base(module, LlamaAttention)
 
 
 def llama_attention_forward(
diff --git a/python/llm/src/ipex_llm/transformers/models/minicpm.py b/python/llm/src/ipex_llm/transformers/models/minicpm.py
index f3c45425..532e992d 100644
--- a/python/llm/src/ipex_llm/transformers/models/minicpm.py
+++ b/python/llm/src/ipex_llm/transformers/models/minicpm.py
@@ -51,7 +51,8 @@ from transformers.cache_utils import Cache
 
 
 def merge_qkv(module: torch.nn.Module):
-    return merge_qkv_base(module, "MiniCPMAttention")
+    merge_qkv_base(module, "MiniCPMAttention")
+    merge_qkv_base(module, "MiniCPMSdpaAttention")
 
 
 def apply_residual_scale(module: torch.nn.Module):
diff --git a/python/llm/src/ipex_llm/transformers/models/minicpmv.py b/python/llm/src/ipex_llm/transformers/models/minicpmv.py
index 9e0f1085..f64fc37c 100644
--- a/python/llm/src/ipex_llm/transformers/models/minicpmv.py
+++ b/python/llm/src/ipex_llm/transformers/models/minicpmv.py
@@ -36,6 +36,7 @@ from transformers.generation.logits_process import RepetitionPenaltyLogitsProces
 # MiniCPM-V-2_5 and MiniCPM-V-2_6
 def merge_qkv(module: torch.nn.Module):
     merge_qkv_base(module, "SiglipAttention")
+    merge_qkv_base(module, "SiglipSdpaAttention")
     merge_qkv_base(module, "Idefics2VisionAttention")
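
Note on why the Sdpa variants need their own registrations: transformers instantiates the *SdpaAttention subclasses (GlmSdpaAttention, Phi3SdpaAttention, and so on) in place of the eager attention classes when config._attn_implementation is "sdpa", the default where SDPA is available, so optimize_model has to register a replacement forward for both class variants. The sketch below illustrates the kind of module walk a convert_forward-style helper performs; convert_forward_sketch is a hypothetical stand-in written for this note, and the real ipex_llm.transformers.convert.convert_forward may differ in its details.

from types import MethodType

import torch


def convert_forward_sketch(model: torch.nn.Module, target_class, new_forward):
    # Walk every submodule and rebind `forward` on exact instances of
    # target_class. Exact class matching (rather than isinstance) is an
    # assumption here, but it is consistent with this patch registering the
    # Sdpa subclasses separately instead of relying on their inheritance from
    # the eager attention classes.
    for _, sub_module in model.named_modules():
        if sub_module.__class__ == target_class:
            sub_module.forward = MethodType(new_forward, sub_module)

Usage would mirror the calls added above, e.g. convert_forward_sketch(model, module.Phi3SdpaAttention, attention_forward).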