[NPU] Update layernorm node on MTL/ARL (#12738)
* Update layernorm node on MTL/ARL * Fix on style
This commit is contained in:
		
							parent
							
								
									d11f257ee7
								
							
						
					
					
						commit
						69f13c78b8
					
				
					 1 changed files with 3 additions and 1 deletions
				
			
		| 
						 | 
				
			
			@ -472,7 +472,9 @@ class LLMBaseNNFactory(NNFactory):
 | 
			
		|||
        )
 | 
			
		||||
        eps = self.constant(self.rms_norm_eps)
 | 
			
		||||
        hidden_states = self.eltwise_div(hidden_states, self.sqrt(self.eltwise_add(variance, eps)))
 | 
			
		||||
        if os.environ.get("IPEX_LLM_NPU_DRIVER_VERSION", None) in ["5716", "5733"]:
 | 
			
		||||
        if os.environ.get("IPEX_LLM_NPU_DRIVER_VERSION", None) in ["5716", "5733"] or \
 | 
			
		||||
           os.environ.get("IPEX_LLM_NPU_MTL", "0") == "1" or \
 | 
			
		||||
           os.environ.get("IPEX_LLM_NPU_ARL", "0") == "1":
 | 
			
		||||
            # to support special drivers
 | 
			
		||||
            hidden_states = self.convert_to_fp16(hidden_states)
 | 
			
		||||
        else:
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in a new issue