diff --git a/python/llm/src/ipex_llm/transformers/models/mistral.py b/python/llm/src/ipex_llm/transformers/models/mistral.py
index f077474f..e6a787b6 100644
--- a/python/llm/src/ipex_llm/transformers/models/mistral.py
+++ b/python/llm/src/ipex_llm/transformers/models/mistral.py
@@ -131,11 +131,11 @@ def compute_attn_outputs_weights(query_states, key_states, value_states, bsz, q_
             )
         attn_weights = attn_weights + attention_mask
-
-    if kv_seq_len >= 2048 or bsz >= 64:
-        # for memory considerations, do not upcast attention to fp32
-        # for long sequences or large batches
-        attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+    if os.getenv("IPEX_LLM_LOW_MEM", '0').lower() in ('true', '1', 't') and \
+            (kv_seq_len >= 2048 or bsz >= 64):
+        # for memory considerations, do not upcast attention to fp32
+        # for long sequences or large batches
+        attn_weights = nn.functional.softmax(attn_weights, dim=-1)
     else:
         # upcast attention to fp32
         attn_weights = nn.functional.softmax(attn_weights, dim=-1,
@@ -838,7 +838,6 @@ def mistral_attention_forward_4_36_quantized(
                 f" but is {attention_mask.size()}"
             )
         attn_weights = attn_weights + attention_mask
-
     if kv_seq_len >= 2048 or bsz >= 64:
         # for memory considerations, do not upcast attention to fp32
        # for long sequences or large batches
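
A minimal, standalone sketch of the behavior the first hunk introduces: the low-memory softmax path (no fp32 upcast) is only taken when `IPEX_LLM_LOW_MEM` is enabled and the sequence is long or the batch is large; otherwise attention is upcast to fp32 as before. The helper name `softmax_with_optional_upcast` and the demo tensor shapes are illustrative only; the environment-variable name, the 2048/64 thresholds, and the two softmax calls are taken from the patch above.

```python
import os

import torch
import torch.nn as nn


def softmax_with_optional_upcast(attn_weights: torch.Tensor,
                                 kv_seq_len: int, bsz: int) -> torch.Tensor:
    """Sketch of the env-gated softmax path (hypothetical helper, not library API)."""
    low_mem = os.getenv("IPEX_LLM_LOW_MEM", "0").lower() in ("true", "1", "t")
    if low_mem and (kv_seq_len >= 2048 or bsz >= 64):
        # low-memory path: softmax in the weights' own (possibly fp16/bf16) dtype
        return nn.functional.softmax(attn_weights, dim=-1)
    # default path: upcast to fp32 for numerical stability, then cast back
    return nn.functional.softmax(attn_weights, dim=-1,
                                 dtype=torch.float32).to(attn_weights.dtype)


if __name__ == "__main__":
    # illustrative shapes: (bsz, num_heads, q_len, kv_seq_len)
    os.environ["IPEX_LLM_LOW_MEM"] = "1"
    scores = torch.randn(2, 8, 16, 4096, dtype=torch.bfloat16)
    probs = softmax_with_optional_upcast(scores, kv_seq_len=4096, bsz=2)
    print(probs.dtype, probs.shape)  # stays bfloat16 on the low-memory path
```

The trade-off gated here is memory versus precision: keeping the softmax in the original half-precision dtype avoids materializing an fp32 copy of the attention-probability tensor (roughly half the memory for long sequences or large batches), while the fp32 upcast remains the default for numerical stability.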