Fix wrong output for Llama models on CPU (#10742)
parent e764f9b1b1
commit 31ea2f9a9f

1 changed file with 16 additions and 4 deletions
@@ -1335,10 +1335,22 @@ def llama_attention_forward_4_36_original(
         key_states = repeat_kv(key_states, self.num_key_value_groups)
         value_states = repeat_kv(value_states, self.num_key_value_groups)
         # otherwise, use native attention
-        attn_output, attn_weights = native_sdp(query_states, key_states, value_states,
-                                               attention_mask,
-                                               bsz, q_len, kv_seq_len,
-                                               self.head_dim, self.num_heads, output_attentions)
+        if not output_attentions:
+            attn_output = torch.nn.functional.scaled_dot_product_attention(
+                query_states,
+                key_states,
+                value_states,
+                attn_mask=attention_mask,
+                dropout_p=self.attention_dropout if self.training else 0.0,
+                # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that
+                # does not create a causal mask in case q_len == 1.
+                is_causal=self.is_causal and attention_mask is None and q_len > 1,
+            )
+        else:
+            attn_output, attn_weights = native_sdp(query_states, key_states, value_states,
+                                                   attention_mask,
+                                                   bsz, q_len, kv_seq_len,
+                                                   self.head_dim, self.num_heads, output_attentions)
 
     attn_output_size = (bsz, self.num_heads, q_len, self.head_dim)
     if attn_output.size() != attn_output_size:
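For context, a minimal sketch (not part of the patch) of why the new fast path is safe when attention weights are not requested: torch.nn.functional.scaled_dot_product_attention computes the same softmax(QK^T / sqrt(head_dim)) @ V result as an explicit implementation, but it never returns the attention weights, which is why the native_sdp fallback is kept for output_attentions=True. The explicit reference below is an illustrative stand-in, not the actual native_sdp source.

# Hypothetical sketch: compare fused SDPA against an explicit causal attention
# reference on CPU. SDPA returns only the attention output tensor, so callers
# that need the weights must take the explicit (native_sdp-style) path.
import math
import torch

bsz, num_heads, q_len, head_dim = 1, 8, 16, 64
q = torch.randn(bsz, num_heads, q_len, head_dim)
k = torch.randn(bsz, num_heads, q_len, head_dim)
v = torch.randn(bsz, num_heads, q_len, head_dim)

# Fast path taken by the new branch: fused SDPA with internal causal masking.
fast = torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)

# Explicit reference path that also exposes the attention weights.
scores = q @ k.transpose(-2, -1) / math.sqrt(head_dim)
causal_mask = torch.triu(torch.full((q_len, q_len), float("-inf")), diagonal=1)
attn_weights = torch.softmax(scores + causal_mask, dim=-1)
slow = attn_weights @ v

print(torch.allclose(fast, slow, atol=1e-5))  # expected: True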
				