optimize qwen2-audio again (#11825)
parent 6a8d07ddb4
commit 17a0beb21f

2 changed files with 8 additions and 24 deletions

@@ -830,6 +830,9 @@ def _optimize_pre(model, qtype=None):
     if model.config.model_type == "qwen2_moe":
         from ipex_llm.transformers.models.qwen2_moe import merge_qkv
         model.apply(merge_qkv)
+    if model.config.model_type == "qwen2_audio":
+        from ipex_llm.transformers.models.qwen2 import merge_qkv
+        model.language_model.apply(merge_qkv)
     if model.config.model_type == "stablelm":
         # For stablelm-zephyr-3b and stablelm-2-zephyr-1_6b
         from ipex_llm.transformers.models.stablelm import merge_qkv
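
The new branch reuses the qwen2 merge_qkv for Qwen2-Audio but applies it only to
model.language_model, i.e. the text decoder, so the audio encoder is left untouched.
A minimal, self-contained sketch of the mechanism this relies on (toy modules, not
ipex_llm code): torch.nn.Module.apply(fn) calls fn on every submodule recursively,
which is how a single apply() call reaches each attention layer of the decoder.

    import torch

    class ToyAttention(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.q_proj = torch.nn.Linear(4, 4)
            self.k_proj = torch.nn.Linear(4, 4)
            self.v_proj = torch.nn.Linear(4, 4)

    def mark_merged(module: torch.nn.Module):
        # stand-in for merge_qkv: only rewrites the attention submodules
        if isinstance(module, ToyAttention):
            module.merged = True

    decoder = torch.nn.Sequential(ToyAttention(), ToyAttention())
    decoder.apply(mark_merged)
    print(all(m.merged for m in decoder))  # True

The remaining two hunks change the qwen2 model code that the import above points to.
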
@@ -45,6 +45,7 @@ import torch
 from torch.nn import CrossEntropyLoss
 from torch.nn.functional import scaled_dot_product_attention as sdpa
 
+from ipex_llm.transformers.models.common import merge_qkv_base
 from ipex_llm.transformers.models.utils import SILU, mlp_fusion_check
 from ipex_llm.transformers.models.utils import should_use_fuse_rope
 from ipex_llm.transformers.models.utils import use_quantize_kv_cache, restore_fp8_kv_cache, \

@@ -465,30 +466,10 @@ def qwen2_causal_lm_forward(
 
 
 def merge_qkv(module: torch.nn.Module):
-    if isinstance(module, Qwen2Attention):
-        new_weight = torch.cat([
-            module.q_proj.weight.data,
-            module.k_proj.weight.data,
-            module.v_proj.weight.data,
-        ], dim=0)
-        new_bias = torch.cat([
-            module.q_proj.bias.data,
-            module.k_proj.bias.data,
-            module.v_proj.bias.data,
-        ], dim=-1)
-
-        qkv_proj = torch.nn.Linear(0, 0, bias=True)
-        qkv_proj.weight = torch.nn.Parameter(new_weight, requires_grad=False)
-        qkv_proj.bias = torch.nn.Parameter(new_bias, requires_grad=False)
-        qkv_proj.in_features = new_weight.size(1)
-        qkv_proj.out_features = new_weight.size(0)
-        module.qkv_proj = qkv_proj
-
-        del module.q_proj, module.k_proj, module.v_proj
-
-        if os.environ.get("IPEX_LLM_LOW_MEM", None) == "1":
-            del module.rotary_emb.cos_cached
-            del module.rotary_emb.sin_cached
+    merge_qkv_base(module, Qwen2Attention)
+    if isinstance(module, Qwen2Attention) and os.environ.get("IPEX_LLM_LOW_MEM", None) == "1":
+        del module.rotary_emb.cos_cached
+        del module.rotary_emb.sin_cached
 
 
 def padding_mlp(module: torch.nn.Module):
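
The hand-written q/k/v concatenation removed above now lives in the shared
merge_qkv_base helper imported from ipex_llm.transformers.models.common; the
qwen2-specific wrapper only keeps the IPEX_LLM_LOW_MEM cleanup of the cached
rotary embeddings. A sketch of the equivalent fusion logic, assuming the shared
helper mirrors the deleted code (the name merge_qkv_sketch is illustrative):

    import torch

    def merge_qkv_sketch(module: torch.nn.Module, attention_cls: type):
        # Fuse the separate q/k/v projections into one linear layer so a
        # single matmul produces all three outputs, then drop the originals.
        if isinstance(module, attention_cls):
            new_weight = torch.cat([
                module.q_proj.weight.data,
                module.k_proj.weight.data,
                module.v_proj.weight.data,
            ], dim=0)
            new_bias = torch.cat([
                module.q_proj.bias.data,
                module.k_proj.bias.data,
                module.v_proj.bias.data,
            ], dim=-1)

            qkv_proj = torch.nn.Linear(0, 0, bias=True)
            qkv_proj.weight = torch.nn.Parameter(new_weight, requires_grad=False)
            qkv_proj.bias = torch.nn.Parameter(new_bias, requires_grad=False)
            qkv_proj.in_features = new_weight.size(1)
            qkv_proj.out_features = new_weight.size(0)
            module.qkv_proj = qkv_proj
            del module.q_proj, module.k_proj, module.v_proj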