From 17a0beb21f0d697e2f1fb1ec8b53788852c86f49 Mon Sep 17 00:00:00 2001
From: Yishuo Wang
Date: Fri, 16 Aug 2024 11:11:35 +0800
Subject: [PATCH] optimize qwen2-audio again (#11825)

---
 .../llm/src/ipex_llm/transformers/convert.py  |  3 ++
 .../src/ipex_llm/transformers/models/qwen2.py | 29 ++++---------------
 2 files changed, 8 insertions(+), 24 deletions(-)

diff --git a/python/llm/src/ipex_llm/transformers/convert.py b/python/llm/src/ipex_llm/transformers/convert.py
index 1adc00b9..2ce495f0 100644
--- a/python/llm/src/ipex_llm/transformers/convert.py
+++ b/python/llm/src/ipex_llm/transformers/convert.py
@@ -830,6 +830,9 @@ def _optimize_pre(model, qtype=None):
     if model.config.model_type == "qwen2_moe":
         from ipex_llm.transformers.models.qwen2_moe import merge_qkv
         model.apply(merge_qkv)
+    if model.config.model_type == "qwen2_audio":
+        from ipex_llm.transformers.models.qwen2 import merge_qkv
+        model.language_model.apply(merge_qkv)
     if model.config.model_type == "stablelm":
         # For stablelm-zephyr-3b and stablelm-2-zephyr-1_6b
         from ipex_llm.transformers.models.stablelm import merge_qkv
diff --git a/python/llm/src/ipex_llm/transformers/models/qwen2.py b/python/llm/src/ipex_llm/transformers/models/qwen2.py
index 32d838cb..8368aa58 100644
--- a/python/llm/src/ipex_llm/transformers/models/qwen2.py
+++ b/python/llm/src/ipex_llm/transformers/models/qwen2.py
@@ -45,6 +45,7 @@ import torch
 from torch.nn import CrossEntropyLoss
 from torch.nn.functional import scaled_dot_product_attention as sdpa
 
+from ipex_llm.transformers.models.common import merge_qkv_base
 from ipex_llm.transformers.models.utils import SILU, mlp_fusion_check
 from ipex_llm.transformers.models.utils import should_use_fuse_rope
 from ipex_llm.transformers.models.utils import use_quantize_kv_cache, restore_fp8_kv_cache, \
@@ -465,30 +466,10 @@ def qwen2_causal_lm_forward(
 
 
 def merge_qkv(module: torch.nn.Module):
-    if isinstance(module, Qwen2Attention):
-        new_weight = torch.cat([
-            module.q_proj.weight.data,
-            module.k_proj.weight.data,
-            module.v_proj.weight.data,
-        ], dim=0)
-        new_bias = torch.cat([
-            module.q_proj.bias.data,
-            module.k_proj.bias.data,
-            module.v_proj.bias.data,
-        ], dim=-1)
-
-        qkv_proj = torch.nn.Linear(0, 0, bias=True)
-        qkv_proj.weight = torch.nn.Parameter(new_weight, requires_grad=False)
-        qkv_proj.bias = torch.nn.Parameter(new_bias, requires_grad=False)
-        qkv_proj.in_features = new_weight.size(1)
-        qkv_proj.out_features = new_weight.size(0)
-        module.qkv_proj = qkv_proj
-
-        del module.q_proj, module.k_proj, module.v_proj
-
-    if os.environ.get("IPEX_LLM_LOW_MEM", None) == "1":
-        del module.rotary_emb.cos_cached
-        del module.rotary_emb.sin_cached
+    merge_qkv_base(module, Qwen2Attention)
+    if isinstance(module, Qwen2Attention) and os.environ.get("IPEX_LLM_LOW_MEM", None) == "1":
+        del module.rotary_emb.cos_cached
+        del module.rotary_emb.sin_cached
 
 
 def padding_mlp(module: torch.nn.Module):
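
Notes: the patch above replaces the hand-written QKV merge in qwen2.py with the shared
merge_qkv_base helper from ipex_llm.transformers.models.common, and reuses that same
merge_qkv for the qwen2_audio language model in _optimize_pre. The following is a minimal
sketch of what merge_qkv_base plausibly does, inferred from the inline code the patch
removes; the actual helper in ipex_llm.transformers.models.common may differ (e.g. in how
it handles attention modules without biases or quantized weights):

    import torch

    def merge_qkv_base(module: torch.nn.Module, attention_class):
        if isinstance(module, attention_class):
            # Concatenate the separate q/k/v projection weights along the
            # output dimension so one matmul yields all three projections.
            new_weight = torch.cat([
                module.q_proj.weight.data,
                module.k_proj.weight.data,
                module.v_proj.weight.data,
            ], dim=0)
            new_bias = torch.cat([
                module.q_proj.bias.data,
                module.k_proj.bias.data,
                module.v_proj.bias.data,
            ], dim=-1)

            # Build an empty Linear, then attach the fused parameters and
            # patch its in/out feature counts to match.
            qkv_proj = torch.nn.Linear(0, 0, bias=True)
            qkv_proj.weight = torch.nn.Parameter(new_weight, requires_grad=False)
            qkv_proj.bias = torch.nn.Parameter(new_bias, requires_grad=False)
            qkv_proj.in_features = new_weight.size(1)
            qkv_proj.out_features = new_weight.size(0)
            module.qkv_proj = qkv_proj

            # Drop the per-projection submodules so their original weight
            # tensors can be freed.
            del module.q_proj, module.k_proj, module.v_proj

With q/k/v fused into a single qkv_proj, the attention forward can issue one GEMM and
split the result into queries, keys, and values, which is the layout the optimized qwen2
attention path already expects from the old inline merge.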