[LLM] Reimplement chatglm fuse rms optimization (#9260)
* re-implement chatglm rope rms * update
This commit is contained in:
		
							parent
							
								
									5a2ce421af
								
							
						
					
					
						commit
						bd5215d75b
					
				
					 2 changed files with 15 additions and 1 deletions
				
			
		| 
						 | 
				
			
			@ -226,6 +226,7 @@ def _optimize_post(model):
 | 
			
		|||
            module = importlib.import_module(modeling_module_name)
 | 
			
		||||
            from bigdl.llm.transformers.models.chatglm2 import chatglm2_attention_forward_8eb45c
 | 
			
		||||
            from bigdl.llm.transformers.models.chatglm2 import core_attn_forward_8eb45c
 | 
			
		||||
            from bigdl.llm.transformers.models.chatglm2 import chatglm_rms_norm_forward
 | 
			
		||||
            convert_forward(model,
 | 
			
		||||
                            module.SelfAttention,
 | 
			
		||||
                            chatglm2_attention_forward_8eb45c
 | 
			
		||||
| 
						 | 
				
			
			@ -235,7 +236,7 @@ def _optimize_post(model):
 | 
			
		|||
                            core_attn_forward_8eb45c)
 | 
			
		||||
            convert_forward(model,
 | 
			
		||||
                            module.RMSNorm,
 | 
			
		||||
                            llama_rms_norm_forward)
 | 
			
		||||
                            chatglm_rms_norm_forward)
 | 
			
		||||
        elif hasattr(model.config, 'vocab_size') and model.config.vocab_size == 130528:
 | 
			
		||||
            # chatglm-6b
 | 
			
		||||
            modeling_module_name = model.__class__.__module__
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -74,6 +74,19 @@ def apply_rotary_pos_emb(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Ten
 | 
			
		|||
    return torch.cat((x_out2, x_pass), dim=-1)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def chatglm_rms_norm_forward(self, hidden_states):
 | 
			
		||||
    if hidden_states.device.type == "xpu" and not (self.training and hidden_states.requires_grad):
 | 
			
		||||
        hidden_states, _ = torch.ops.torch_ipex.rms_norm(hidden_states,
 | 
			
		||||
                                                         [self.weight.size(0)], self.weight)
 | 
			
		||||
    else:
 | 
			
		||||
        input_dtype = hidden_states.dtype
 | 
			
		||||
        hidden_states = hidden_states.to(torch.float32)
 | 
			
		||||
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
 | 
			
		||||
        hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
 | 
			
		||||
        return self.weight * hidden_states.to(input_dtype)
 | 
			
		||||
    return hidden_states
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def chatglm2_attention_forward_8eb45c(
 | 
			
		||||
        self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True
 | 
			
		||||
):
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in a new issue