optimize qwen2-audio again (#11825)
parent 6a8d07ddb4
commit 17a0beb21f
2 changed files with 8 additions and 24 deletions
@@ -830,6 +830,9 @@ def _optimize_pre(model, qtype=None):
     if model.config.model_type == "qwen2_moe":
         from ipex_llm.transformers.models.qwen2_moe import merge_qkv
         model.apply(merge_qkv)
+    if model.config.model_type == "qwen2_audio":
+        from ipex_llm.transformers.models.qwen2 import merge_qkv
+        model.language_model.apply(merge_qkv)
     if model.config.model_type == "stablelm":
         # For stablelm-zephyr-3b and stablelm-2-zephyr-1_6b
         from ipex_llm.transformers.models.stablelm import merge_qkv
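
The new branch only touches the text decoder: Qwen2-Audio wraps an audio encoder plus a Qwen2 language model, so model.language_model.apply(merge_qkv) recursively fuses Q/K/V in every decoder attention layer while leaving the audio tower untouched. A hedged sketch of how this path is reached; the class names and loading flow below are assumptions based on the standard transformers / ipex-llm APIs, not part of this diff:

    # Sketch only: assumes transformers ships Qwen2AudioForConditionalGeneration
    # (config.model_type == "qwen2_audio") and that ipex_llm.optimize_model routes
    # through the _optimize_pre branch shown above.
    from transformers import Qwen2AudioForConditionalGeneration
    from ipex_llm import optimize_model

    model = Qwen2AudioForConditionalGeneration.from_pretrained("Qwen/Qwen2-Audio-7B-Instruct")
    # _optimize_pre sees model_type "qwen2_audio" and applies merge_qkv to
    # model.language_model (the Qwen2 decoder), not to the audio encoder.
    model = optimize_model(model)
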
@@ -45,6 +45,7 @@ import torch
 from torch.nn import CrossEntropyLoss
 from torch.nn.functional import scaled_dot_product_attention as sdpa
 
+from ipex_llm.transformers.models.common import merge_qkv_base
 from ipex_llm.transformers.models.utils import SILU, mlp_fusion_check
 from ipex_llm.transformers.models.utils import should_use_fuse_rope
 from ipex_llm.transformers.models.utils import use_quantize_kv_cache, restore_fp8_kv_cache, \
@@ -465,30 +466,10 @@ def qwen2_causal_lm_forward(
 
 
 def merge_qkv(module: torch.nn.Module):
-    if isinstance(module, Qwen2Attention):
-        new_weight = torch.cat([
-            module.q_proj.weight.data,
-            module.k_proj.weight.data,
-            module.v_proj.weight.data,
-        ], dim=0)
-        new_bias = torch.cat([
-            module.q_proj.bias.data,
-            module.k_proj.bias.data,
-            module.v_proj.bias.data,
-        ], dim=-1)
-
-        qkv_proj = torch.nn.Linear(0, 0, bias=True)
-        qkv_proj.weight = torch.nn.Parameter(new_weight, requires_grad=False)
-        qkv_proj.bias = torch.nn.Parameter(new_bias, requires_grad=False)
-        qkv_proj.in_features = new_weight.size(1)
-        qkv_proj.out_features = new_weight.size(0)
-        module.qkv_proj = qkv_proj
-
-        del module.q_proj, module.k_proj, module.v_proj
-
-        if os.environ.get("IPEX_LLM_LOW_MEM", None) == "1":
-            del module.rotary_emb.cos_cached
-            del module.rotary_emb.sin_cached
+    merge_qkv_base(module, Qwen2Attention)
+    if isinstance(module, Qwen2Attention) and os.environ.get("IPEX_LLM_LOW_MEM", None) == "1":
+        del module.rotary_emb.cos_cached
+        del module.rotary_emb.sin_cached
 
 
 def padding_mlp(module: torch.nn.Module):
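
The hand-written fusion above is folded into the shared merge_qkv_base helper imported from ipex_llm.transformers.models.common. Its source is not part of this commit; the sketch below reconstructs what it presumably does from the deleted lines, as an assumption rather than the helper's actual implementation:

    import torch

    def merge_qkv_base(module: torch.nn.Module, attention_class) -> None:
        # Assumed behaviour, mirroring the deleted code: fuse the separate
        # q/k/v projections of a matching attention module into one qkv_proj.
        if isinstance(module, attention_class):
            new_weight = torch.cat([
                module.q_proj.weight.data,
                module.k_proj.weight.data,
                module.v_proj.weight.data,
            ], dim=0)
            new_bias = torch.cat([
                module.q_proj.bias.data,
                module.k_proj.bias.data,
                module.v_proj.bias.data,
            ], dim=-1)
            qkv_proj = torch.nn.Linear(0, 0, bias=True)
            qkv_proj.weight = torch.nn.Parameter(new_weight, requires_grad=False)
            qkv_proj.bias = torch.nn.Parameter(new_bias, requires_grad=False)
            qkv_proj.in_features = new_weight.size(1)
            qkv_proj.out_features = new_weight.size(0)
            module.qkv_proj = qkv_proj
            del module.q_proj, module.k_proj, module.v_proj

The IPEX_LLM_LOW_MEM rotary-cache cleanup stays outside the helper, which is why the new merge_qkv keeps a separate isinstance guard for it.
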