Reduce Mistral softmax memory only in low memory mode (#11775)
* Reduce Mistral softmax memory only in low memory mode
This commit is contained in:
parent aa861df066
commit a88c132e54
1 changed file with 5 additions and 6 deletions
@@ -131,11 +131,11 @@ def compute_attn_outputs_weights(query_states, key_states, value_states, bsz, q_
             )

         attn_weights = attn_weights + attention_mask

-    if kv_seq_len >= 2048 or bsz >= 64:
-        # for memory considerations, do not upcast attention to fp32
-        # for long sequences or large batches
-        attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+    if os.getenv("IPEX_LLM_LOW_MEM", '0').lower() in ('true', '1', 't'):
+        if kv_seq_len >= 2048 or bsz >= 64:
+            # for memory considerations, do not upcast attention to fp32
+            # for long sequences or large batches
+            attn_weights = nn.functional.softmax(attn_weights, dim=-1)
     else:
         # upcast attention to fp32
         attn_weights = nn.functional.softmax(attn_weights, dim=-1,
@@ -838,7 +838,6 @@ def mistral_attention_forward_4_36_quantized(
                 f" but is {attention_mask.size()}"
             )
         attn_weights = attn_weights + attention_mask

-    if kv_seq_len >= 2048 or bsz >= 64:
-        # for memory considerations, do not upcast attention to fp32
-        # for long sequences or large batches
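The gist of the change: the non-upcast softmax path, which keeps attention weights in their current dtype to save memory on long sequences or large batches, now only runs when the IPEX_LLM_LOW_MEM environment variable is set; otherwise the weights are upcast to fp32 as before. Below is a minimal sketch of that gating logic for illustration only. The helper names low_mem_enabled and softmax_attn_weights are not from the diff, the environment check and the size check are folded into a single condition for brevity, and the fp32 fallback call is completed from the usual transformers pattern rather than copied from the truncated hunk.

# Minimal sketch, not the ipex-llm implementation: softmax dtype selection
# gated by the IPEX_LLM_LOW_MEM environment variable, as in this commit.
import os

import torch
import torch.nn as nn


def low_mem_enabled() -> bool:
    # Same truthy parsing as the diff: "true", "1" or "t", case-insensitive.
    return os.getenv("IPEX_LLM_LOW_MEM", '0').lower() in ('true', '1', 't')


def softmax_attn_weights(attn_weights: torch.Tensor, kv_seq_len: int, bsz: int) -> torch.Tensor:
    if low_mem_enabled() and (kv_seq_len >= 2048 or bsz >= 64):
        # Low-memory mode with a long sequence or large batch:
        # keep the current dtype (e.g. fp16) and skip the fp32 temporary.
        return nn.functional.softmax(attn_weights, dim=-1)
    # Default path: upcast to fp32 for numerical stability, then cast back.
    return nn.functional.softmax(attn_weights, dim=-1,
                                 dtype=torch.float32).to(attn_weights.dtype)

With this gate in place, running a workload with, for example, IPEX_LLM_LOW_MEM=1 set in the environment opts into the reduced-memory softmax for long prompts or large batches; leaving the variable unset keeps the fp32 upcast everywhere.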