parent 3e601f9a5d
commit 7b1d9ad7c0

2 changed files with 8 additions and 5 deletions
@@ -284,7 +284,7 @@ def llama_attention_forward_4_31(
                                                      value_states,
                                                      is_causal=True)
         attn_weights = None
-    elif use_esimd_sdp(q_len, self.head_dim, query_states):
+    elif use_esimd_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
         import linear_fp16_esimd
         attn_output = linear_fp16_esimd.sdp_forward(query_states,
                                                     key_states,
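For reference, a minimal sketch (not part of this commit) of where the new argument comes from: in HF-style llama attention, key_states is assumed to be laid out as [batch, num_kv_heads, kv_seq_len, head_dim], so key_states.shape[2] is the current key/value sequence length that the call above now forwards as k_len.

    # Sketch only; shapes assumed from HF llama, not shown in this diff.
    import torch

    key_states = torch.empty(1, 32, 5, 128)   # [batch, num_kv_heads, kv_seq_len, head_dim]
    k_len = key_states.shape[2]                # 5 here, below the new k_len < 8 threshold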
@@ -689,11 +689,11 @@ def llama_attention_forward_4_36(
                                                      value_states,
                                                      is_causal=True)
         attn_weights = None
-    elif use_esimd_sdp(q_len, self.head_dim, query_states):
+    elif use_esimd_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
         import linear_fp16_esimd
         attn_output = linear_fp16_esimd.sdp_forward(query_states,
-                                                    key_states.contiguous(),
-                                                    value_states.contiguous())
+                                                    key_states,
+                                                    value_states)
         attn_output = attn_output.view(query_states.shape)
         attn_weights = None
     else:
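Side note, not part of the commit: the .contiguous() calls dropped above only cost anything when the tensors are non-contiguous views; on an already-contiguous tensor, .contiguous() returns the tensor itself, otherwise it materializes a copy. A small sketch:

    # Sketch only: how .contiguous() behaves on contiguous vs. transposed tensors.
    import torch

    x = torch.empty(1, 32, 16, 128)
    print(x.is_contiguous())            # True  -> x.contiguous() is x, no copy
    print(x.contiguous() is x)          # True
    y = x.transpose(1, 2)
    print(y.is_contiguous())            # False -> y.contiguous() would allocate a copy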
@@ -230,13 +230,16 @@ def use_flash_attention(query, key):
     return True


-def use_esimd_sdp(q_len, head_dim, query_states):
+def use_esimd_sdp(q_len, k_len, head_dim, query_states):
     if head_dim != 128:
         # esimd_sdp only support head_dim = 128 now
         return False
     elif q_len != 1:
         # esimd_sdp only support rest token now
         return False
+    elif k_len < 8:
+        # esimd_sdp will cause wrong output when k_len < 8
+        return False
     elif query_states.device.type != "xpu":
         # esimd_sdp only support GPU now
         return False
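Putting the hunk above together, a hedged sketch of how the updated gate behaves. The checks after the xpu test fall outside this hunk, so the trailing return True is an assumption made only to keep the sketch runnable, not the real tail of the function.

    # Sketch reconstructed from the hunk above; the rest of the real function is not shown here.
    import torch

    def use_esimd_sdp(q_len, k_len, head_dim, query_states):
        if head_dim != 128:
            return False                  # esimd_sdp only supports head_dim = 128
        elif q_len != 1:
            return False                  # only the rest-token (decode) case, q_len == 1
        elif k_len < 8:
            return False                  # new check: wrong output when k_len < 8
        elif query_states.device.type != "xpu":
            return False                  # GPU (xpu) only
        return True                       # assumption: remaining checks omitted

    q = torch.empty(1, 32, 1, 128)        # CPU tensor in this sketch
    print(use_esimd_sdp(q_len=1, k_len=5, head_dim=128, query_states=q))    # False (k_len < 8)
    print(use_esimd_sdp(q_len=1, k_len=16, head_dim=128, query_states=q))   # False (not on xpu)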