[NPU] Qwen prefill attn_mask type hotfix (#12395)

* qwen prefill attn_mask type fp16

* update: skip fp16 conversion of attention_mask in prefill mode (convert only for decode)
This commit is contained in:
Yina Chen 2024-11-13 11:51:34 +02:00 committed by GitHub
parent 9220babaab
commit d6d63d6b84
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -247,7 +247,8 @@ class LLMBaseNNFactory(NNFactory):
attn_weight = self.matmul(query_states, key_states, False, True) / (
math.sqrt(head_dim)
)
attention_mask = self.convert_to_fp16(attention_mask)
if mode != "prefill":
attention_mask = self.convert_to_fp16(attention_mask)
attn_weight = self.eltwise_add(attn_weight, attention_mask)
attn_weight = self.convert_to_fp32(attn_weight)
attn_weight = self.softmax(attn_weight, -1)