optimize npu llama2 first token performance (#11451)
parent 4e4ecd5095
commit 029ff15d28

1 changed file with 19 additions and 7 deletions
@@ -219,13 +219,25 @@ def llama_attention_forward(
     else:
         causal_mask = None
 
-    attn_output = torch.nn.functional.scaled_dot_product_attention(
-        query_states,
-        key_states,
-        value_states,
-        attn_mask=causal_mask,
-        is_causal=self.is_causal and attention_mask is None and q_len > 1,
-    )
+    if query_states.size(2) == key_states.size(2):
+        # first token
+        from intel_npu_acceleration_library.functional import scaled_dot_product_attention
+        attn_output = scaled_dot_product_attention(
+            query_states,
+            key_states,
+            value_states,
+            attn_mask=causal_mask,
+            is_causal=self.is_causal and causal_mask is None and q_len > 1,
+        )
+    else:
+        # second+ token
+        attn_output = torch.nn.functional.scaled_dot_product_attention(
+            query_states,
+            key_states,
+            value_states,
+            attn_mask=causal_mask,
+            is_causal=self.is_causal and causal_mask is None and q_len > 1,
+        )
 
     attn_output = attn_output.transpose(1, 2).contiguous()
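The routing in this patch hinges on comparing the query and key sequence lengths: during the first-token (prefill) pass the whole prompt is processed at once, so query_states.size(2) == key_states.size(2) and the call is sent to the scaled_dot_product_attention kernel from intel_npu_acceleration_library; on later tokens a single query attends to the full KV cache, the lengths differ, and the code keeps torch's SDPA. Below is a minimal, self-contained sketch of that dispatch, not the library's code: the attention_dispatch function name, the tensor shapes, and the ImportError fallback are illustrative additions, and torch SDPA stands in for the NPU kernel when the library is not installed.

import torch
import torch.nn.functional as F

def attention_dispatch(query_states, key_states, value_states,
                       causal_mask, is_causal_flag, q_len):
    # Prefill: query length equals key length, so take the NPU-optimized path.
    if query_states.size(2) == key_states.size(2):
        try:
            # Same import the patch uses for the first-token path.
            from intel_npu_acceleration_library.functional import (
                scaled_dot_product_attention as sdpa,
            )
        except ImportError:
            # Fallback purely for illustration on machines without the library.
            sdpa = F.scaled_dot_product_attention
        return sdpa(
            query_states, key_states, value_states,
            attn_mask=causal_mask,
            is_causal=is_causal_flag and causal_mask is None and q_len > 1,
        )
    # Decode: one new query token attends to the whole KV cache; keep torch SDPA.
    return F.scaled_dot_product_attention(
        query_states, key_states, value_states,
        attn_mask=causal_mask,
        is_causal=is_causal_flag and causal_mask is None and q_len > 1,
    )

# Illustrative shapes: batch=1, 32 heads, head_dim=128.
q_prefill = torch.randn(1, 32, 16, 128)   # 16 prompt tokens -> prefill branch
kv_prefill = torch.randn(1, 32, 16, 128)
out = attention_dispatch(q_prefill, kv_prefill, kv_prefill, None, True, 16)

q_decode = torch.randn(1, 32, 1, 128)     # 1 new token -> decode branch
kv_decode = torch.randn(1, 32, 17, 128)   # full KV cache
out = attention_dispatch(q_decode, kv_decode, kv_decode, None, True, 1)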