LLM: update esimd sdp kernel (#9871)
parent 023679459e
commit 3e05c9e11b
1 changed file with 2 additions and 2 deletions
@@ -276,8 +276,8 @@ def llama_attention_forward_4_31(
     elif use_esimd_sdp(q_len, self.head_dim, query_states):
         import linear_fp16_esimd
         attn_output = linear_fp16_esimd.sdp_forward(query_states,
-                                                    key_states.contiguous(),
-                                                    value_states.contiguous())
+                                                    key_states,
+                                                    value_states)
         attn_output = attn_output.view(query_states.shape)
         attn_weights = None
     else:
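
For context, below is a minimal sketch of how the changed branch might sit inside the surrounding attention forward. Only the identifiers visible in the diff (use_esimd_sdp, linear_fp16_esimd.sdp_forward, query_states, key_states, value_states, q_len, head_dim) come from the commit; the stub dispatch check and the plain PyTorch fallback path are illustrative assumptions, not the repository's actual implementation.

import math
import torch


def use_esimd_sdp(q_len, head_dim, query_states):
    # Hypothetical stub for illustration only: the real check lives elsewhere in
    # the repo and decides when the fused ESIMD SDP kernel applies (assumed here:
    # fp16 tensors resident on an Intel XPU device).
    return query_states.dtype == torch.float16 and query_states.device.type == "xpu"


def sdp_attention(query_states, key_states, value_states, head_dim, q_len):
    # Illustrative sketch of the branch touched by this commit.
    if use_esimd_sdp(q_len, head_dim, query_states):
        import linear_fp16_esimd
        # After this change the kernel is called with the key/value tensors as-is;
        # the earlier .contiguous() copies are no longer made on the Python side.
        attn_output = linear_fp16_esimd.sdp_forward(query_states,
                                                    key_states,
                                                    value_states)
        attn_output = attn_output.view(query_states.shape)
        attn_weights = None
    else:
        # Assumed plain PyTorch fallback for cases the ESIMD kernel does not cover.
        attn_weights = torch.matmul(query_states,
                                    key_states.transpose(2, 3)) / math.sqrt(head_dim)
        attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1,
                                                   dtype=torch.float32)
        attn_weights = attn_weights.to(query_states.dtype)
        attn_output = torch.matmul(attn_weights, value_states)
    return attn_output, attn_weights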