LLM: add bs limitation for llama softmax upcast to fp32 (#10752)
commit c3fc8f4b90
parent 0d518aab8d
1 changed file with 9 additions and 6 deletions
@@ -1034,8 +1034,9 @@ def llama_attention_forward_4_36_quantized(
                 )
             attn_weights = attn_weights + attention_mask
 
-        if kv_seq_len >= 2048:
-            # for memory considerations, do not upcast attention to fp32 for long sequences
+        if kv_seq_len >= 2048 or bsz >= 64:
+            # for memory considerations, do not upcast attention to fp32
+            # for long sequences or large batches
             attn_weights = nn.functional.softmax(attn_weights, dim=-1)
         else:
             # upcast attention to fp32
@@ -1079,8 +1080,9 @@ def llama_attention_forward_4_36_quantized(
                 )
             attn_weights = attn_weights + attention_mask
 
-        if kv_seq_len >= 2048:
-            # for memory considerations, do not upcast attention to fp32 for long sequences
+        if kv_seq_len >= 2048 or bsz >= 64:
+            # for memory considerations, do not upcast attention to fp32
+            # for long sequences or large batches
             attn_weights = nn.functional.softmax(attn_weights, dim=-1)
         else:
             # upcast attention to fp32
@@ -1379,8 +1381,9 @@ def native_sdp(query, key, value, attention_mask,
                           f"but is {attention_mask.size()}")
         attn_weights = attn_weights + attention_mask
 
-    if kv_seq_len >= 2048:
-        # for memory considerations, do not upcast attention to fp32 for long sequences
+    if kv_seq_len >= 2048 or bsz >= 64:
+        # for memory considerations, do not upcast attention to fp32
+        # for long sequences or large batches
         attn_weights = nn.functional.softmax(attn_weights, dim=-1)
     else:
         # upcast attention to fp32
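For context, a minimal standalone sketch of the gating logic that all three hunks introduce (the same two-line change is applied in both quantized attention paths and in native_sdp). The function name softmax_with_upcast_limit is hypothetical, and the fp32 branch body is not visible in the hunks above, so the dtype=torch.float32 softmax followed by a cast back to the input dtype is an assumption about the surrounding code.

import torch
import torch.nn as nn

def softmax_with_upcast_limit(attn_weights: torch.Tensor,
                              kv_seq_len: int, bsz: int) -> torch.Tensor:
    # attn_weights: (bsz, num_heads, q_len, kv_seq_len), typically fp16/bf16
    if kv_seq_len >= 2048 or bsz >= 64:
        # for memory considerations, do not upcast attention to fp32
        # for long sequences or large batches
        return nn.functional.softmax(attn_weights, dim=-1)
    # upcast attention to fp32 for numerical stability, then cast back
    # (assumed to mirror the untouched else-branch of the patched code)
    return nn.functional.softmax(attn_weights, dim=-1,
                                 dtype=torch.float32).to(attn_weights.dtype)

With bsz >= 64, even a short kv_seq_len now skips the fp32 upcast, keeping the attention matrix in its original half-precision dtype and roughly halving its memory footprint for large batches.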