LLM: add bs limitation for llama softmax upcast to fp32 (#10752)
parent 0d518aab8d
commit c3fc8f4b90

1 changed file with 9 additions and 6 deletions
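The guard exists because the attention-weights tensor that softmax operates on has shape (bsz, num_heads, q_len, kv_seq_len), so its footprint grows with both batch size and sequence length, and upcasting it to fp32 doubles that footprint relative to fp16. A rough back-of-the-envelope estimate in Python, with illustrative shape values that are not taken from this commit:

    # Illustrative sizes only: a prefill step with a large batch.
    bsz, num_heads, q_len, kv_seq_len = 64, 32, 1024, 1024

    elems = bsz * num_heads * q_len * kv_seq_len   # elements in attn_weights
    fp16_mib = elems * 2 / 1024**2                 # kept in fp16 (2 bytes/elem)
    fp32_mib = elems * 4 / 1024**2                 # upcast to fp32 (4 bytes/elem)
    print(f"attn_weights: {fp16_mib:.0f} MiB in fp16 vs {fp32_mib:.0f} MiB in fp32")

With these numbers the upcast alone adds roughly 4 GiB of extra activation memory for a single softmax call, which is why the threshold now also triggers on bsz >= 64 and not just on a long kv_seq_len.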
@@ -1034,8 +1034,9 @@ def llama_attention_forward_4_36_quantized(
             )
         attn_weights = attn_weights + attention_mask

-    if kv_seq_len >= 2048:
-        # for memory considerations, do not upcast attention to fp32 for long sequences
+    if kv_seq_len >= 2048 or bsz >= 64:
+        # for memory considerations, do not upcast attention to fp32
+        # for long sequences or large batches
         attn_weights = nn.functional.softmax(attn_weights, dim=-1)
     else:
         # upcast attention to fp32
@@ -1079,8 +1080,9 @@ def llama_attention_forward_4_36_quantized(
             )
         attn_weights = attn_weights + attention_mask

-    if kv_seq_len >= 2048:
-        # for memory considerations, do not upcast attention to fp32 for long sequences
+    if kv_seq_len >= 2048 or bsz >= 64:
+        # for memory considerations, do not upcast attention to fp32
+        # for long sequences or large batches
         attn_weights = nn.functional.softmax(attn_weights, dim=-1)
     else:
         # upcast attention to fp32
@@ -1379,8 +1381,9 @@ def native_sdp(query, key, value, attention_mask,
                           f"but is {attention_mask.size()}")
         attn_weights = attn_weights + attention_mask

-    if kv_seq_len >= 2048:
-        # for memory considerations, do not upcast attention to fp32 for long sequences
+    if kv_seq_len >= 2048 or bsz >= 64:
+        # for memory considerations, do not upcast attention to fp32
+        # for long sequences or large batches
         attn_weights = nn.functional.softmax(attn_weights, dim=-1)
     else:
         # upcast attention to fp32
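All three hunks apply the same guard in each code path. Below is a minimal standalone sketch of the resulting logic, assuming attn_weights is a half-precision tensor of shape (bsz, num_heads, q_len, kv_seq_len); the thresholds mirror the diff, while the body of the fp32 branch is assumed from the usual Hugging Face pattern, since the diff only shows its comment.

    import torch
    import torch.nn as nn

    def softmax_with_upcast_guard(attn_weights, kv_seq_len, bsz, out_dtype):
        # Sketch of the guard this commit introduces; thresholds mirror the diff.
        if kv_seq_len >= 2048 or bsz >= 64:
            # for memory considerations, do not upcast attention to fp32
            # for long sequences or large batches
            return nn.functional.softmax(attn_weights, dim=-1)
        else:
            # upcast attention to fp32, then cast back to the caller's dtype
            # (this cast-back line is assumed; the diff only shows the comment)
            return nn.functional.softmax(attn_weights, dim=-1,
                                         dtype=torch.float32).to(out_dtype)

    # Example: a small batch with a short context still takes the fp32 path.
    attn = torch.randn(2, 32, 16, 16, dtype=torch.float16)
    probs = softmax_with_upcast_guard(attn, kv_seq_len=16, bsz=2, out_dtype=attn.dtype)

The trade-off is modest: softmax kept in fp16 is slightly less numerically robust than the fp32 upcast, but for long sequences or large batches it halves the peak memory of the attention probabilities.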