From bc5008f0d5a54e08ce8c6483614f0db31764c017 Mon Sep 17 00:00:00 2001 From: Yishuo Wang Date: Tue, 28 May 2024 17:25:53 +0800 Subject: [PATCH] disable sdp_causal in phi-3 to fix overflow (#11157) --- .../src/ipex_llm/transformers/models/phi3.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/python/llm/src/ipex_llm/transformers/models/phi3.py b/python/llm/src/ipex_llm/transformers/models/phi3.py index 0dc3ee76..c398ae61 100644 --- a/python/llm/src/ipex_llm/transformers/models/phi3.py +++ b/python/llm/src/ipex_llm/transformers/models/phi3.py @@ -139,14 +139,15 @@ def attention_forward( else: attn_output = xe_addons.sdp(query_states, key_states, value_states, attention_mask) - elif use_sdp_causal(q_len, kv_seq_len, self.head_dim, query_states, self.training): - import xe_addons - if isinstance(past_key_value, DynamicFp8Cache): - attn_output = xe_addons.sdp_fp8_causal(query_states, key_states, - value_states, attention_mask) - else: - attn_output = xe_addons.sdp_causal(query_states, key_states, - value_states, attention_mask) + # disable sdp_causal to avoid overflow for now + # elif use_sdp_causal(q_len, kv_seq_len, self.head_dim, query_states, self.training): + # import xe_addons + # if isinstance(past_key_value, DynamicFp8Cache): + # attn_output = xe_addons.sdp_fp8_causal(query_states, key_states, + # value_states, attention_mask) + # else: + # attn_output = xe_addons.sdp_causal(query_states, key_states, + # value_states, attention_mask) else: if isinstance(past_key_value, DynamicFp8Cache): key_states, value_states = restore_fp8_kv_cache(key_states, value_states,