[NPU] Qwen prefill attn_mask type hotfix (#12395)
* qwen prefill attn_mask type fp16
* update
parent 9220babaab
commit d6d63d6b84
1 changed file with 2 additions and 1 deletion
@@ -247,6 +247,7 @@ class LLMBaseNNFactory(NNFactory):
         attn_weight = self.matmul(query_states, key_states, False, True) / (
             math.sqrt(head_dim)
         )
-        attention_mask = self.convert_to_fp16(attention_mask)
+        if mode != "prefill":
+            attention_mask = self.convert_to_fp16(attention_mask)
         attn_weight = self.eltwise_add(attn_weight, attention_mask)
         attn_weight = self.convert_to_fp32(attn_weight)
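For context, here is a minimal NumPy sketch of why the mask's dtype must match the fp16 attention weights before the elementwise add. This is an analogue, not the actual NNFactory graph API; the shapes, values, and variable names are illustrative assumptions. NumPy silently promotes a mixed fp16/fp32 add to fp32, whereas a statically compiled NPU graph has no such promotion, so the dtypes must agree. Judging by the commit title, the prefill path presumably already supplies the mask as fp16, so the explicit cast is now limited to non-prefill (decode) mode.

    # Toy NumPy analogue (not the NNFactory API); shapes/values are made up.
    import math
    import numpy as np

    head_dim = 4
    query_states = np.random.randn(1, 2, head_dim).astype(np.float16)
    key_states = np.random.randn(1, 2, head_dim).astype(np.float16)

    # attn_weight = QK^T / sqrt(head_dim), computed in fp16 as on the NPU path.
    attn_weight = (query_states @ key_states.transpose(0, 2, 1)) / math.sqrt(head_dim)
    assert attn_weight.dtype == np.float16

    # A causal mask supplied in fp32, as a decode-time caller might pass it.
    attention_mask = np.triu(np.full((2, 2), -65504.0, dtype=np.float32), k=1)

    # Without an explicit cast, NumPy promotes the sum to fp32; a static
    # NPU graph would instead reject or mishandle the dtype mismatch.
    promoted = attn_weight + attention_mask
    print(promoted.dtype)  # float32 -- dtype changed under us

    # The decode branch of the fix: cast the mask to fp16 first, keeping
    # the whole softmax input in a single dtype.
    attention_mask_fp16 = attention_mask.astype(np.float16)
    masked = attn_weight + attention_mask_fp16
    print(masked.dtype)  # float16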