[NPU] Fix abnormal output with latest driver (#12530)

This commit is contained in:
binbin Deng 2024-12-12 17:56:30 +08:00 committed by GitHub
parent ffce86d69f
commit f36c23664f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -471,7 +471,7 @@ class LLMBaseNNFactory(NNFactory):
)
eps = self.constant(self.rms_norm_eps)
hidden_states = self.eltwise_div(hidden_states, self.sqrt(self.eltwise_add(variance, eps)))
layernorm_weight = self.convert_to_fp32(layernorm_weight)
hidden_states = self.convert_to_fp16(hidden_states)
hidden_states = self.eltwise_mul(layernorm_weight, hidden_states)
hidden_states = self.convert_to_fp16(hidden_states)
return hidden_states