parent
							
								
									b785376f5c
								
							
						
					
					
						commit
						69c49d21f5
					
				
					 1 changed files with 8 additions and 4 deletions
				
			
		| 
						 | 
				
			
			@ -74,22 +74,26 @@ def get_ipex_version():
 | 
			
		|||
 | 
			
		||||
 | 
			
		||||
def llama_rms_norm_forward(self, hidden_states):
 | 
			
		||||
    optimized_rms_norm = False
 | 
			
		||||
    if hidden_states.device.type == "xpu" and not (self.training and hidden_states.requires_grad):
 | 
			
		||||
        if get_ipex_version() <= "2.0.110+xpu":
 | 
			
		||||
            if self.variance_epsilon == 1e-6:
 | 
			
		||||
                hidden_states, _ = torch.ops.torch_ipex.rms_norm(hidden_states,
 | 
			
		||||
                                                                 [self.weight.size(0)],
 | 
			
		||||
                                                                 self.weight)
 | 
			
		||||
                optimized_rms_norm = True
 | 
			
		||||
            else:
 | 
			
		||||
                import linear_q4_0
 | 
			
		||||
                hidden_states = linear_q4_0.fused_rms_norm(hidden_states,
 | 
			
		||||
                                                           [self.weight.size(0)],
 | 
			
		||||
                                                           self.weight,
 | 
			
		||||
                                                           None,
 | 
			
		||||
                                                           self.variance_epsilon)
 | 
			
		||||
        else:
 | 
			
		||||
            hidden_states = torch.ops.torch_ipex.fast_rms_norm(hidden_states,
 | 
			
		||||
                                                               [self.weight.size(0)],
 | 
			
		||||
                                                               self.weight,
 | 
			
		||||
                                                               None,
 | 
			
		||||
                                                               self.variance_epsilon)
 | 
			
		||||
            optimized_rms_norm = True
 | 
			
		||||
    if not optimized_rms_norm:
 | 
			
		||||
    else:
 | 
			
		||||
        input_dtype = hidden_states.dtype
 | 
			
		||||
        hidden_states = hidden_states.to(torch.float32)
 | 
			
		||||
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in a new issue