From a88c132e546144963e1cb8e0c9c0cf3f1b6f71e0 Mon Sep 17 00:00:00 2001
From: Qiyuan Gong
Date: Tue, 13 Aug 2024 14:50:54 +0800
Subject: [PATCH] Reduce Mistral softmax memory only in low memory mode (#11775)

* Reduce Mistral softmax memory only in low memory mode
---
 .../llm/src/ipex_llm/transformers/models/mistral.py | 11 +++++------
 1 file changed, 5 insertions(+), 6 deletions(-)

diff --git a/python/llm/src/ipex_llm/transformers/models/mistral.py b/python/llm/src/ipex_llm/transformers/models/mistral.py
index f077474f..e6a787b6 100644
--- a/python/llm/src/ipex_llm/transformers/models/mistral.py
+++ b/python/llm/src/ipex_llm/transformers/models/mistral.py
@@ -131,11 +131,11 @@ def compute_attn_outputs_weights(query_states, key_states, value_states, bsz, q_
             )
 
         attn_weights = attn_weights + attention_mask
-
-    if kv_seq_len >= 2048 or bsz >= 64:
-        # for memory considerations, do not upcast attention to fp32
-        # for long sequences or large batches
-        attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+    if os.getenv("IPEX_LLM_LOW_MEM", '0').lower() in ('true', '1', 't'):
+        if kv_seq_len >= 2048 or bsz >= 64:
+            # for memory considerations, do not upcast attention to fp32
+            # for long sequences or large batches
+            attn_weights = nn.functional.softmax(attn_weights, dim=-1)
     else:
         # upcast attention to fp32
         attn_weights = nn.functional.softmax(attn_weights, dim=-1,
@@ -838,7 +838,6 @@ def mistral_attention_forward_4_36_quantized(
                 f" but is {attention_mask.size()}"
             )
             attn_weights = attn_weights + attention_mask
-
             if kv_seq_len >= 2048 or bsz >= 64:
                 # for memory considerations, do not upcast attention to fp32
                 # for long sequences or large batches
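
Note on the new gate (an editor's sketch, not part of the patch): the change keys the fp16 softmax path off the IPEX_LLM_LOW_MEM environment variable, treating "true", "1" or "t" (case-insensitive) as enabled and anything else, including the unset default '0', as disabled. Below is a minimal standalone sketch of that gating, assuming the intended behavior is to skip the fp32 upcast only when low-memory mode is on and the sequence or batch is large; the two conditions are folded into a single "and" here for brevity, and the function names are illustrative rather than ipex-llm APIs.

    import os

    import torch
    from torch import nn


    def low_mem_enabled() -> bool:
        # Same truthiness check as the patch: IPEX_LLM_LOW_MEM counts as
        # enabled for "true", "1" or "t" (case-insensitive); default is off.
        return os.getenv("IPEX_LLM_LOW_MEM", '0').lower() in ('true', '1', 't')


    def gated_attention_softmax(attn_weights, kv_seq_len, bsz):
        # In low-memory mode with a long sequence (>= 2048) or large batch
        # (>= 64), keep the softmax in the working dtype (e.g. fp16) to save
        # memory; otherwise upcast to fp32 for numerical stability and cast
        # back, matching the default branch kept by the patch.
        if low_mem_enabled() and (kv_seq_len >= 2048 or bsz >= 64):
            return nn.functional.softmax(attn_weights, dim=-1)
        return nn.functional.softmax(attn_weights, dim=-1,
                                     dtype=torch.float32).to(attn_weights.dtype)

Enabling the low-memory path then only requires exporting the variable before inference, for example IPEX_LLM_LOW_MEM=1 python your_script.py (the script name is a placeholder), or setting os.environ["IPEX_LLM_LOW_MEM"] = "1" before the forward pass.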