optimize qwen2-audio again (#11825)

Yishuo Wang 2024-08-16 11:11:35 +08:00 committed by GitHub
parent 6a8d07ddb4
commit 17a0beb21f
2 changed files with 8 additions and 24 deletions


@@ -830,6 +830,9 @@ def _optimize_pre(model, qtype=None):
     if model.config.model_type == "qwen2_moe":
         from ipex_llm.transformers.models.qwen2_moe import merge_qkv
         model.apply(merge_qkv)
+    if model.config.model_type == "qwen2_audio":
+        from ipex_llm.transformers.models.qwen2 import merge_qkv
+        model.language_model.apply(merge_qkv)
     if model.config.model_type == "stablelm":
         # For stablelm-zephyr-3b and stablelm-2-zephyr-1_6b
         from ipex_llm.transformers.models.stablelm import merge_qkv
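Note on the new qwen2_audio branch above: the audio model keeps its Qwen2 text decoder under the language_model attribute, and torch.nn.Module.apply visits every submodule recursively, so the existing qwen2 merge_qkv can be reused unchanged on that subtree. A minimal sketch of this dispatch pattern, using hypothetical DummyAttention / DummyAudioModel stand-ins rather than real ipex-llm or transformers classes:

import torch

class DummyAttention(torch.nn.Module):
    """Hypothetical stand-in for Qwen2Attention with separate q/k/v projections."""
    def __init__(self, hidden=8):
        super().__init__()
        self.q_proj = torch.nn.Linear(hidden, hidden, bias=True)
        self.k_proj = torch.nn.Linear(hidden, hidden, bias=True)
        self.v_proj = torch.nn.Linear(hidden, hidden, bias=True)

class DummyAudioModel(torch.nn.Module):
    """Hypothetical stand-in for a qwen2_audio-style wrapper: the text decoder
    lives under `language_model`, so only that subtree is optimized."""
    def __init__(self):
        super().__init__()
        self.audio_tower = torch.nn.Linear(8, 8)  # untouched by merge_qkv
        self.language_model = torch.nn.Sequential(DummyAttention(), DummyAttention())

def merge_qkv(module: torch.nn.Module):
    # Fires only on attention modules; all other submodules pass through.
    if isinstance(module, DummyAttention):
        print("would fuse q/k/v in", module.__class__.__name__)

model = DummyAudioModel()
model.language_model.apply(merge_qkv)  # Module.apply recurses over submodules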


@@ -45,6 +45,7 @@ import torch
 from torch.nn import CrossEntropyLoss
 from torch.nn.functional import scaled_dot_product_attention as sdpa
+from ipex_llm.transformers.models.common import merge_qkv_base
 from ipex_llm.transformers.models.utils import SILU, mlp_fusion_check
 from ipex_llm.transformers.models.utils import should_use_fuse_rope
 from ipex_llm.transformers.models.utils import use_quantize_kv_cache, restore_fp8_kv_cache, \
@@ -465,30 +466,10 @@ def qwen2_causal_lm_forward(
 def merge_qkv(module: torch.nn.Module):
-    if isinstance(module, Qwen2Attention):
-        new_weight = torch.cat([
-            module.q_proj.weight.data,
-            module.k_proj.weight.data,
-            module.v_proj.weight.data,
-        ], dim=0)
-        new_bias = torch.cat([
-            module.q_proj.bias.data,
-            module.k_proj.bias.data,
-            module.v_proj.bias.data,
-        ], dim=-1)
-        qkv_proj = torch.nn.Linear(0, 0, bias=True)
-        qkv_proj.weight = torch.nn.Parameter(new_weight, requires_grad=False)
-        qkv_proj.bias = torch.nn.Parameter(new_bias, requires_grad=False)
-        qkv_proj.in_features = new_weight.size(1)
-        qkv_proj.out_features = new_weight.size(0)
-        module.qkv_proj = qkv_proj
-        del module.q_proj, module.k_proj, module.v_proj
-        if os.environ.get("IPEX_LLM_LOW_MEM", None) == "1":
-            del module.rotary_emb.cos_cached
-            del module.rotary_emb.sin_cached
+    merge_qkv_base(module, Qwen2Attention)
+    if isinstance(module, Qwen2Attention) and os.environ.get("IPEX_LLM_LOW_MEM", None) == "1":
+        del module.rotary_emb.cos_cached
+        del module.rotary_emb.sin_cached
 def padding_mlp(module: torch.nn.Module):
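The inline fusion deleted above is what the shared helper is presumably doing for any attention class with separate q_proj/k_proj/v_proj layers: concatenate the three weights (and biases) into one fused linear so each attention layer can run a single matmul instead of three. A sketch of that fusion under those assumptions; the fuse_qkv name here is hypothetical, and the real merge_qkv_base in ipex_llm.transformers.models.common may differ in details:

import torch

def fuse_qkv(module: torch.nn.Module, attention_cls: type):
    # Illustrative re-implementation of the fusion removed above (the actual
    # helper in ipex-llm is merge_qkv_base); only acts on the given class.
    if isinstance(module, attention_cls):
        new_weight = torch.cat([
            module.q_proj.weight.data,
            module.k_proj.weight.data,
            module.v_proj.weight.data,
        ], dim=0)
        new_bias = torch.cat([
            module.q_proj.bias.data,
            module.k_proj.bias.data,
            module.v_proj.bias.data,
        ], dim=-1)
        # Build an empty Linear, then attach the fused parameters to it.
        qkv_proj = torch.nn.Linear(0, 0, bias=True)
        qkv_proj.weight = torch.nn.Parameter(new_weight, requires_grad=False)
        qkv_proj.bias = torch.nn.Parameter(new_bias, requires_grad=False)
        qkv_proj.in_features = new_weight.size(1)
        qkv_proj.out_features = new_weight.size(0)
        module.qkv_proj = qkv_proj
        del module.q_proj, module.k_proj, module.v_proj

Moving the fusion into a shared helper is what lets the qwen2_audio path in the first hunk reuse it; the only model-specific piece left in qwen2.py is the IPEX_LLM_LOW_MEM branch that drops the cached rotary-embedding cos/sin tensors.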