optimize npu llama2 first token performance (#11451)
This commit is contained in:

parent 4e4ecd5095
commit 029ff15d28

1 changed file with 19 additions and 7 deletions
@@ -219,13 +219,25 @@ def llama_attention_forward(
     else:
         causal_mask = None
 
-    attn_output = torch.nn.functional.scaled_dot_product_attention(
-        query_states,
-        key_states,
-        value_states,
-        attn_mask=causal_mask,
-        is_causal=self.is_causal and attention_mask is None and q_len > 1,
-    )
+    if query_states.size(2) == key_states.size(2):
+        # first token
+        from intel_npu_acceleration_library.functional import scaled_dot_product_attention
+        attn_output = scaled_dot_product_attention(
+            query_states,
+            key_states,
+            value_states,
+            attn_mask=causal_mask,
+            is_causal=self.is_causal and causal_mask is None and q_len > 1,
+        )
+    else:
+        # second+ token
+        attn_output = torch.nn.functional.scaled_dot_product_attention(
+            query_states,
+            key_states,
+            value_states,
+            attn_mask=causal_mask,
+            is_causal=self.is_causal and causal_mask is None and q_len > 1,
+        )
 
     attn_output = attn_output.transpose(1, 2).contiguous()
 
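For context on the dispatch this patch introduces: query_states.size(2) == key_states.size(2) holds only on the prefill (first-token) pass, where the queries and the cached keys span the same sequence length, so that pass is routed to the NPU-accelerated scaled_dot_product_attention from intel_npu_acceleration_library; on later decode steps a single new query attends to a longer KV cache and the stock torch.nn.functional path is kept. The shape-only sketch below illustrates the check; the tensor sizes and the choose_sdpa_path helper are illustrative assumptions, not part of the commit.

import torch

def choose_sdpa_path(query_states: torch.Tensor, key_states: torch.Tensor) -> str:
    # Dim 2 is the sequence-length axis of [batch, num_heads, seq_len, head_dim] tensors.
    # Prefill (first token): every prompt position is queried at once, so q_len == kv_len.
    # Decode (second+ token): one new query attends to the full KV cache, so q_len < kv_len.
    if query_states.size(2) == key_states.size(2):
        return "first token: intel_npu_acceleration_library scaled_dot_product_attention"
    return "second+ token: torch.nn.functional.scaled_dot_product_attention"

# Prefill step: 8 prompt tokens against 8 cached positions.
q = torch.randn(1, 32, 8, 128)
k = torch.randn(1, 32, 8, 128)
print(choose_sdpa_path(q, k))

# Decode step: 1 new token against 9 cached positions.
q = torch.randn(1, 32, 1, 128)
k = torch.randn(1, 32, 9, 128)
print(choose_sdpa_path(q, k))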