disable sdp_causal in phi-3 to fix overflow (#11157)
This commit is contained in:
		
							parent
							
								
									33852bd23e
								
							
						
					
					
						commit
						bc5008f0d5
					
				
					 1 changed file with 9 additions and 8 deletions
				
			
		| 
						 | 
					@ -139,14 +139,15 @@ def attention_forward(
 | 
				
			||||||
        else:
 | 
					        else:
 | 
				
			||||||
            attn_output = xe_addons.sdp(query_states, key_states, value_states,
 | 
					            attn_output = xe_addons.sdp(query_states, key_states, value_states,
 | 
				
			||||||
                                        attention_mask)
 | 
					                                        attention_mask)
 | 
				
			||||||
    elif use_sdp_causal(q_len, kv_seq_len, self.head_dim, query_states, self.training):
 | 
					    # disable sdp_causal to avoid overflow for now
 | 
				
			||||||
        import xe_addons
 | 
					    # elif use_sdp_causal(q_len, kv_seq_len, self.head_dim, query_states, self.training):
 | 
				
			||||||
        if isinstance(past_key_value, DynamicFp8Cache):
 | 
					    #     import xe_addons
 | 
				
			||||||
            attn_output = xe_addons.sdp_fp8_causal(query_states, key_states,
 | 
					    #     if isinstance(past_key_value, DynamicFp8Cache):
 | 
				
			||||||
                                                   value_states, attention_mask)
 | 
					    #         attn_output = xe_addons.sdp_fp8_causal(query_states, key_states,
 | 
				
			||||||
        else:
 | 
					    #                                                value_states, attention_mask)
 | 
				
			||||||
            attn_output = xe_addons.sdp_causal(query_states, key_states,
 | 
					    #     else:
 | 
				
			||||||
                                               value_states, attention_mask)
 | 
					    #         attn_output = xe_addons.sdp_causal(query_states, key_states,
 | 
				
			||||||
 | 
					    #                                            value_states, attention_mask)
 | 
				
			||||||
    else:
 | 
					    else:
 | 
				
			||||||
        if isinstance(past_key_value, DynamicFp8Cache):
 | 
					        if isinstance(past_key_value, DynamicFp8Cache):
 | 
				
			||||||
            key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
 | 
					            key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in a new issue