[NPU] Qwen prefill attn_mask type hotfix (#12395)
* qwen prefill attn_mask type fp16
* update
This commit is contained in:

parent 9220babaab
commit d6d63d6b84
1 changed file with 2 additions and 1 deletion
@@ -247,7 +247,8 @@ class LLMBaseNNFactory(NNFactory):
             attn_weight = self.matmul(query_states, key_states, False, True) / (
                 math.sqrt(head_dim)
             )
-            attention_mask = self.convert_to_fp16(attention_mask)
+            if mode != "prefill":
+                attention_mask = self.convert_to_fp16(attention_mask)
             attn_weight = self.eltwise_add(attn_weight, attention_mask)
             attn_weight = self.convert_to_fp32(attn_weight)
             attn_weight = self.softmax(attn_weight, -1)
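The fix scopes the fp16 cast of the attention mask to decode mode only: per the commit message, the Qwen prefill graph already produces the mask in fp16, so the unconditional cast is skipped there. Below is a minimal NumPy sketch of the resulting control flow; apply_attention_mask, the dtypes chosen for the demo masks, and the fp32-mask assumption for decode are illustrative, not part of the repository.

import numpy as np

def apply_attention_mask(attn_weight, attention_mask, mode):
    # Post-fix logic: only non-prefill (decode) graphs cast the mask
    # down to fp16; the prefill mask is assumed to be fp16 already,
    # so the redundant conversion is skipped.
    if mode != "prefill":
        attention_mask = attention_mask.astype(np.float16)  # convert_to_fp16
    attn_weight = attn_weight + attention_mask              # eltwise_add
    return attn_weight.astype(np.float32)                   # convert_to_fp32

# Decode: mask arrives fp32 and is cast; prefill: mask is already fp16.
w = np.zeros((1, 4), dtype=np.float16)
print(apply_attention_mask(w, np.zeros((1, 4), np.float32), "decode").dtype)
print(apply_attention_mask(w, np.zeros((1, 4), np.float16), "prefill").dtype)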