Optimize qwen 1.5 14B batch performance (#11370)

parent 5aa3e427a9
commit f0fdfa081b

2 changed files with 35 additions and 1 deletion
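This commit pads the MLP (gate/up/down projection) weights of Qwen1.5-14B so that the intermediate dimension becomes a multiple of 256, giving the batched GEMMs alignment-friendly shapes. A quick sketch of the rounding rule used below (the 13696 figure is an assumption taken from Qwen1.5-14B's published config, not from this diff):

# Round intermediate_size up to the next multiple of 256, as padding_mlp does.
intermediate_size = 13696  # assumed Qwen1.5-14B value; not shown in this diff
padding_intermediate_size = (intermediate_size + 256 - 1) // 256 * 256
print(intermediate_size % 256)    # 128 -> not aligned, so padding applies
print(padding_intermediate_size)  # 13824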
				
			
@@ -735,6 +735,8 @@ def _optimize_pre(model):
     if model.config.model_type == "qwen2":
         from ipex_llm.transformers.models.qwen2 import merge_qkv
         model.apply(merge_qkv)
+        from ipex_llm.transformers.models.qwen2 import padding_mlp
+        model.apply(padding_mlp)
     if model.config.model_type == "qwen2_moe":
         from ipex_llm.transformers.models.qwen2_moe import merge_qkv
         model.apply(merge_qkv)
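The two hunks that follow are from the module the imports above point at, presumably ipex_llm/transformers/models/qwen2.py. Note that _optimize_pre hands padding_mlp to torch.nn.Module.apply, which calls it on every submodule recursively; the callback then filters by type with an isinstance check. A minimal sketch of that pattern (the tag_linear helper is hypothetical, for illustration only):

import torch

def tag_linear(module: torch.nn.Module):
    # Module.apply visits every submodule; act only on the type we care about.
    if isinstance(module, torch.nn.Linear):
        module.tagged = True  # hypothetical marker attribute

model = torch.nn.Sequential(torch.nn.Linear(4, 8), torch.nn.ReLU(), torch.nn.Linear(8, 4))
model.apply(tag_linear)
print([getattr(m, "tagged", False) for m in model])  # [True, False, True]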
@@ -49,7 +49,8 @@ from ipex_llm.transformers.models.utils import use_flash_attention, use_sdp, use
 from ipex_llm.transformers.kv import DynamicFp8Cache, DynamicNormalCache
 from ipex_llm.utils.common import invalidInputError
 
-from transformers.models.qwen2.modeling_qwen2 import Qwen2Attention, apply_rotary_pos_emb, repeat_kv
+from transformers.models.qwen2.modeling_qwen2 import Qwen2Attention, Qwen2MLP
+from transformers.models.qwen2.modeling_qwen2 import apply_rotary_pos_emb, repeat_kv
 from transformers.models.qwen2.modeling_qwen2 import _prepare_4d_causal_attention_mask_for_sdpa
 from transformers.models.qwen2.modeling_qwen2 import _prepare_4d_causal_attention_mask
 from transformers.modeling_outputs import BaseModelOutputWithPast
@@ -288,6 +289,37 @@ def merge_qkv(module: torch.nn.Module):
         del module.q_proj, module.k_proj, module.v_proj
 
 
+def padding_mlp(module: torch.nn.Module):
+    # for qwen 1.5 14B
+    if isinstance(module, Qwen2MLP):
+        hidden_size = module.hidden_size
+        intermediate_size = module.intermediate_size
+        padding_intermediate_size = (intermediate_size + 256 - 1) // 256 * 256
+        if intermediate_size % 256 == 0:
+            return
+
+        gate_weight = module.gate_proj.weight.data
+        new_gate_weight = torch.zeros([padding_intermediate_size, hidden_size],
+                                      dtype=gate_weight.dtype, device=gate_weight.device)
+        new_gate_weight[:intermediate_size, :] = gate_weight
+        module.gate_proj.out_features = padding_intermediate_size
+        module.gate_proj.weight = torch.nn.Parameter(new_gate_weight, requires_grad=False)
+
+        up_weight = module.up_proj.weight.data
+        new_up_weight = torch.zeros([padding_intermediate_size, hidden_size],
+                                    dtype=up_weight.dtype, device=up_weight.device)
+        new_up_weight[:intermediate_size, :] = up_weight
+        module.up_proj.out_features = padding_intermediate_size
+        module.up_proj.weight = torch.nn.Parameter(new_up_weight, requires_grad=False)
+
+        down_weight = module.down_proj.weight.data
+        new_down_weight = torch.zeros([hidden_size, padding_intermediate_size],
+                                      dtype=down_weight.dtype, device=down_weight.device)
+        new_down_weight[:, :intermediate_size] = down_weight
+        module.down_proj.in_features = padding_intermediate_size
+        module.down_proj.weight = torch.nn.Parameter(new_down_weight, requires_grad=False)
+
+
 def qwen2_attention_forward(
     self,
     hidden_states: torch.Tensor,
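Zero-padding preserves the MLP's output exactly: Qwen2MLP computes down_proj(silu(gate_proj(x)) * up_proj(x)), the padded rows of gate_proj and up_proj contribute silu(0) * 0 = 0, and the matching columns of down_proj are zero as well, so only the GEMM shapes change. A self-contained sanity check of that argument with toy sizes (not part of the commit; the real code pads to a multiple of 256):

import torch
import torch.nn.functional as F

torch.manual_seed(0)
hidden, inter, padded = 64, 100, 128  # toy sizes; the real model's dims are much larger

gate = torch.nn.Linear(hidden, inter, bias=False)
up = torch.nn.Linear(hidden, inter, bias=False)
down = torch.nn.Linear(inter, hidden, bias=False)

def mlp(x, g, u, d):
    # Same structure as Qwen2MLP.forward: down(silu(gate(x)) * up(x))
    return d(F.silu(g(x)) * u(x))

# Padded copies, mirroring what padding_mlp does to the weights
gate_p = torch.nn.Linear(hidden, padded, bias=False)
up_p = torch.nn.Linear(hidden, padded, bias=False)
down_p = torch.nn.Linear(padded, hidden, bias=False)
with torch.no_grad():
    gate_p.weight.zero_()
    gate_p.weight[:inter] = gate.weight
    up_p.weight.zero_()
    up_p.weight[:inter] = up.weight
    down_p.weight.zero_()
    down_p.weight[:, :inter] = down.weight

x = torch.randn(2, hidden)
print(torch.allclose(mlp(x, gate, up, down), mlp(x, gate_p, up_p, down_p), atol=1e-6))  # True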