From 0fbb10259a44564dd9a096627eb1c80af636bcb3 Mon Sep 17 00:00:00 2001 From: Yishuo Wang Date: Wed, 28 Aug 2024 17:35:05 +0800 Subject: [PATCH] use sdp_causal to reduce internvl2-4b memory usage when the IPEX_LLM_LOW_MEM environment variable is set to 1 (#11953) --- .../src/ipex_llm/transformers/models/phi3.py | 20 ++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/python/llm/src/ipex_llm/transformers/models/phi3.py b/python/llm/src/ipex_llm/transformers/models/phi3.py index 0b7f873e..fa6c43d6 100644 --- a/python/llm/src/ipex_llm/transformers/models/phi3.py +++ b/python/llm/src/ipex_llm/transformers/models/phi3.py @@ -160,15 +160,17 @@ def attention_forward( else: attn_output = xe_addons.sdp(query_states, key_states, value_states, attention_mask) - # disable sdp_causal to avoid overflow for now - # elif use_sdp_causal(q_len, kv_seq_len, self.head_dim, query_states, self.training): - # import xe_addons - # if isinstance(past_key_value, DynamicFp8Cache): - # attn_output = xe_addons.sdp_fp8_causal(query_states, key_states, - # value_states, attention_mask) - # else: - # attn_output = xe_addons.sdp_causal(query_states, key_states, - # value_states, attention_mask) + elif ( + use_sdp_causal(q_len, kv_seq_len, self.head_dim, query_states, self.training) + and os.environ.get("IPEX_LLM_LOW_MEM", "0") == "1" + ): + import xe_addons + if isinstance(past_key_value, DynamicFp8Cache): + attn_output = xe_addons.sdp_fp8_causal(query_states, key_states, + value_states, attention_mask) + else: + attn_output = xe_addons.sdp_causal(query_states, key_states, + value_states, attention_mask) else: if use_quantizekv: key_states, value_states = restore_fp8_kv_cache(key_states, value_states,