Optimize InternLM XComposer performance again (#11551)
This commit is contained in:
		
							parent
							
								
									61613b210c
								
							
						
					
					
						commit
						994e49a510
					
				
					 2 changed files with 7 additions and 4 deletions
				
			
		| 
						 | 
				
			
			@ -1261,6 +1261,7 @@ def _optimize_post(model, lightweight_bmm=False):
 | 
			
		|||
        convert_forward(model, module.InternLM2Attention, internlm_xcomposser2_attention_forward)
 | 
			
		||||
        from ipex_llm.transformers.models.internlm import internlm_xcomposser2_mlp_forward
 | 
			
		||||
        convert_forward(model, module.InternLM2MLP, internlm_xcomposser2_mlp_forward)
 | 
			
		||||
        convert_forward(model, module.InternLM2RMSNorm, llama_rms_norm_forward)
 | 
			
		||||
        from ipex_llm.transformers.models.internlm import internlm_xcomposser2_chat
 | 
			
		||||
        model.chat = MethodType(internlm_xcomposser2_chat, model)
 | 
			
		||||
    elif model.config.model_type == "qwen":
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -359,12 +359,14 @@ def internlm_xcomposser2_attention_forward(
 | 
			
		|||
        kv_seq_len += past_key_value[0].shape[-2]
 | 
			
		||||
 | 
			
		||||
    # IPEX-LLM OPT: fuse rope
 | 
			
		||||
    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
 | 
			
		||||
    if should_use_fuse_rope(hidden_states, position_ids, self.training):
 | 
			
		||||
        query_states, key_states = apply_rotary_pos_emb_cache_freq_xpu(
 | 
			
		||||
            query_states, key_states, sin, cos, "internlm", position_ids
 | 
			
		||||
        )
 | 
			
		||||
        # This fused rope path gives a wrong result if context_length > max_position_embeddings (32768)
 | 
			
		||||
        # we assume context_length <= 32768
 | 
			
		||||
        import xe_addons
 | 
			
		||||
        xe_addons.rotary_half_inplaced(self.rotary_emb.inv_freq, position_ids,
 | 
			
		||||
                                       query_states, key_states)
 | 
			
		||||
    else:
 | 
			
		||||
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
 | 
			
		||||
        query_states, key_states = apply_rotary_pos_emb(
 | 
			
		||||
            query_states, key_states, cos, sin, position_ids, "internlm")
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in a new issue