[LLM] Add rope optimization for internlm (#9159)
* add rope and norm optimization for internlm and gptneox
* revert gptneox back and split with pr#9155
* add norm_forward
* style fix
* update
* update
parent f754ab3e60
commit e7aa67e141
2 changed files with 18 additions and 9 deletions
@@ -311,6 +311,10 @@ def optimize(model):
                         module.InternLMAttention,
                         internlm_attention_forward
                         )
+        convert_forward(model,
+                        module.InternLMRMSNorm,
+                        llama_rms_norm_forward
+                        )
     elif model.config.model_type == "qwen":
         modeling_module_name = model.__class__.__module__
         module = importlib.import_module(modeling_module_name)
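For context on this hunk: a helper like `convert_forward` typically walks the model tree and monkey-patches the `forward` method of every submodule of the target class. A minimal sketch of that pattern, assuming a recursive-replacement approach (the repo's actual `convert_forward` may differ in detail):

def convert_forward(m, target_m, new_forward):
    # Recursively visit every child module of `m`.
    for _, sub_m in m.named_children():
        if isinstance(sub_m, target_m):
            # Bind the replacement function to this instance so `self`
            # inside `new_forward` refers to the submodule being patched.
            bound_method = new_forward.__get__(sub_m, sub_m.__class__)
            setattr(sub_m, "forward", bound_method)
        convert_forward(sub_m, target_m, new_forward)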
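The same hunk routes `InternLMRMSNorm` through `llama_rms_norm_forward`, which works because InternLM's RMSNorm computes the same function as LLaMA's. For reference, the standard RMSNorm both layers implement (a sketch with attribute names borrowed from transformers' `LlamaRMSNorm`; the repo's optimized forward may additionally dispatch to fused XPU kernels):

import torch

def rms_norm_forward_sketch(self, hidden_states):
    # RMSNorm: divide by the root mean square over the feature dim,
    # computed in fp32 for numerical stability, then apply the weight.
    input_dtype = hidden_states.dtype
    hidden_states = hidden_states.to(torch.float32)
    variance = hidden_states.pow(2).mean(-1, keepdim=True)
    hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
    return self.weight * hidden_states.to(input_dtype)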
@@ -74,15 +74,20 @@ def internlm_attention_forward(
     kv_seq_len = key_states.shape[-2]
     if past_key_value is not None:
         kv_seq_len += past_key_value[0].shape[-2]
-    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
-    query_states, key_states = apply_rotary_pos_emb(
-        query_states,
-        key_states,
-        cos,
-        sin,
-        position_ids,
-        "internlm"
-    )
+    if query_states.device.type == "xpu" and not (self.training and query_states.requires_grad):
+        query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states,
+                                                                     key_states,
+                                                                     position_ids,
+                                                                     "internlm")
+    else:
+        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+        query_states, key_states = apply_rotary_pos_emb(
+            query_states,
+            key_states,
+            cos,
+            sin,
+            position_ids,
+            "internlm")
     # [bsz, nh, t, hd]

     if past_key_value is not None:
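This hunk adds an XPU fast path: during inference on an Intel GPU (no gradients needed), `apply_rotary_pos_emb_no_cache_xpu` applies the rotary embedding without first materializing the cos/sin cache via `self.rotary_emb`; training and non-XPU runs keep the original path. For reference, a sketch of the llama-style RoPE application the fallback branch performs (function name and exact cache indexing here are assumptions, following older transformers conventions):

import torch

def rotate_half(x):
    # Rotate the two halves of the head dim: (x1, x2) -> (-x2, x1).
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2:]
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb_sketch(q, k, cos, sin, position_ids):
    # Gather the cos/sin rows for each token position, then broadcast
    # over batch and heads: cos/sin become [bs, 1, seq_len, head_dim].
    cos = cos.squeeze(1).squeeze(0)[position_ids].unsqueeze(1)
    sin = sin.squeeze(1).squeeze(0)[position_ids].unsqueeze(1)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed

Skipping the cache presumably avoids building and gathering the full cos/sin tables on every decode step, which is where the XPU speedup comes from.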