[NPU] Update layernorm node on MTL/ARL (#12738)

* Update layernorm node on MTL/ARL

* Fix on style
This commit is contained in:
Yuwen Hu 2025-01-23 17:25:19 +08:00 committed by GitHub
parent d11f257ee7
commit 69f13c78b8
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@@ -472,7 +472,9 @@ class LLMBaseNNFactory(NNFactory):
)
eps = self.constant(self.rms_norm_eps)
hidden_states = self.eltwise_div(hidden_states, self.sqrt(self.eltwise_add(variance, eps)))
-if os.environ.get("IPEX_LLM_NPU_DRIVER_VERSION", None) in ["5716", "5733"]:
+if os.environ.get("IPEX_LLM_NPU_DRIVER_VERSION", None) in ["5716", "5733"] or \
+os.environ.get("IPEX_LLM_NPU_MTL", "0") == "1" or \
+os.environ.get("IPEX_LLM_NPU_ARL", "0") == "1":
# to support special drivers
hidden_states = self.convert_to_fp16(hidden_states)
else: