Reduce Mistral softmax memory only in low memory mode (#11775)

* Reduce Mistral softmax memory only in low memory mode: skip the fp32 upcast of attention weights for long sequences or large batches only when the IPEX_LLM_LOW_MEM environment variable is enabled.
Qiyuan Gong, 2024-08-13 14:50:54 +08:00 (committed by GitHub)
parent aa861df066
commit a88c132e54

@@ -131,11 +131,11 @@ def compute_attn_outputs_weights(query_states, key_states, value_states, bsz, q_
             )
         attn_weights = attn_weights + attention_mask
-    if kv_seq_len >= 2048 or bsz >= 64:
-        # for memory considerations, do not upcast attention to fp32
-        # for long sequences or large batches
-        attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+    if os.getenv("IPEX_LLM_LOW_MEM", '0').lower() in ('true', '1', 't'):
+        if kv_seq_len >= 2048 or bsz >= 64:
+            # for memory considerations, do not upcast attention to fp32
+            # for long sequences or large batches
+            attn_weights = nn.functional.softmax(attn_weights, dim=-1)
     else:
         # upcast attention to fp32
         attn_weights = nn.functional.softmax(attn_weights, dim=-1,
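
The new gate reads the IPEX_LLM_LOW_MEM environment variable; only when it is set does the fp16 softmax shortcut apply for long sequences (kv_seq_len >= 2048) or large batches (bsz >= 64), otherwise the attention weights are upcast to fp32 as before. A minimal, flattened sketch of that decision as a standalone helper (the helper name and the upcast-then-cast-back default branch are illustrative, mirroring the pattern visible in the diff):

import os

import torch
import torch.nn as nn


def softmax_maybe_upcast(attn_weights, kv_seq_len, bsz):
    # Illustrative helper, not the repository's code: mirrors the gated logic above.
    low_mem = os.getenv("IPEX_LLM_LOW_MEM", '0').lower() in ('true', '1', 't')
    if low_mem and (kv_seq_len >= 2048 or bsz >= 64):
        # Low-memory mode: keep the current dtype (e.g. fp16) so no fp32 copy
        # of the attention-score tensor is materialized.
        return nn.functional.softmax(attn_weights, dim=-1)
    # Default path: upcast to fp32 for numerical stability, then cast back.
    return nn.functional.softmax(attn_weights, dim=-1,
                                 dtype=torch.float32).to(attn_weights.dtype)

Skipping the upcast avoids a float32 copy of the (bsz, num_heads, q_len, kv_seq_len) score tensor at the cost of reduced numerical headroom, which is why the shortcut is now opt-in.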
@@ -838,7 +838,6 @@ def mistral_attention_forward_4_36_quantized(
                 f" but is {attention_mask.size()}"
             )
         attn_weights = attn_weights + attention_mask
-    if kv_seq_len >= 2048 or bsz >= 64:
-        # for memory considerations, do not upcast attention to fp32
-        # for long sequences or large batches
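
Based on the check introduced above, the low-memory softmax path is off by default; it can be enabled from the environment before the model is loaded, for example:

import os

# Accepted values per the check above: 'true', '1', or 't' (case-insensitive).
os.environ["IPEX_LLM_LOW_MEM"] = "1"

Setting the variable in the shell (e.g. export IPEX_LLM_LOW_MEM=1) before starting the process has the same effect.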