add sdp causal support in llama (#11705)
parent 736a7ef72e
commit 8d1e0bd2f4

1 changed file with 14 additions and 0 deletions
@@ -1497,6 +1497,13 @@ def llama_attention_forward_4_41_original(
                                                      value_states.to(device, dtype=torch.float16),
                                                      is_causal=True)
         attn_weights = None
+    elif use_sdp_causal(q_len, kv_seq_len, self.head_dim,
+                        query_states, self.training):
+        import xe_addons
+        attn_output = xe_addons.sdp_causal(query_states, key_states.contiguous(),
+                                           value_states.contiguous(), new_attention_mask)
+        attn_output = attn_output.view(query_states.shape)
+        attn_weights = None
     elif not self.training and not hidden_states.requires_grad and \
             use_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
         import xe_addons
@@ -2040,6 +2047,13 @@ def llama_attention_forward_4_38_original(
                                                      value_states.to(device, dtype=torch.float16),
                                                      is_causal=True)
         attn_weights = None
+    elif use_sdp_causal(q_len, kv_seq_len, self.head_dim,
+                        query_states, self.training):
+        import xe_addons
+        attn_output = xe_addons.sdp_causal(query_states, key_states.contiguous(),
+                                           value_states.contiguous(), new_attention_mask)
+        attn_output = attn_output.view(query_states.shape)
+        attn_weights = None
     elif not self.training and not hidden_states.requires_grad and \
             use_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
         import xe_addons
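For reference, both patched functions gain the same branch: when `use_sdp_causal(q_len, kv_seq_len, self.head_dim, query_states, self.training)` accepts the call, attention is computed by the fused `xe_addons.sdp_causal` kernel instead of falling through to the existing `use_sdp` path. A minimal sketch of the equivalent computation with stock PyTorch standing in for the fused kernel (the `[batch, num_heads, seq_len, head_dim]` layout and the mask semantics are assumptions read off the surrounding code, not part of this patch):

import torch.nn.functional as F

def sdp_causal_reference(query_states, key_states, value_states, attention_mask=None):
    # Stand-in for xe_addons.sdp_causal: causal scaled-dot-product attention
    # over [batch, num_heads, seq_len, head_dim] tensors.
    attn_output = F.scaled_dot_product_attention(
        query_states,
        key_states.contiguous(),
        value_states.contiguous(),
        # new_attention_mask in the patch; PyTorch disallows attn_mask together
        # with is_causal=True, so only enable is_causal when no mask is given.
        attn_mask=attention_mask,
        is_causal=attention_mask is None,
    )
    return attn_output.view(query_states.shape)

Because the new `elif` sits above the `use_sdp` branch, calls that qualify for the causal kernel take it first; everything else falls through to the previous behaviour unchanged.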