LLM: add bs limitation for llama softmax upcast to fp32 (#10752)

binbin Deng 2024-04-12 15:40:25 +08:00 committed by GitHub
parent 0d518aab8d
commit c3fc8f4b90


@@ -1034,8 +1034,9 @@ def llama_attention_forward_4_36_quantized(
             )
             attn_weights = attn_weights + attention_mask
-        if kv_seq_len >= 2048:
-            # for memory considerations, do not upcast attention to fp32 for long sequences
+        if kv_seq_len >= 2048 or bsz >= 64:
+            # for memory considerations, do not upcast attention to fp32
+            # for long sequences or large batches
             attn_weights = nn.functional.softmax(attn_weights, dim=-1)
         else:
             # upcast attention to fp32
@@ -1079,8 +1080,9 @@ def llama_attention_forward_4_36_quantized(
             )
             attn_weights = attn_weights + attention_mask
-        if kv_seq_len >= 2048:
-            # for memory considerations, do not upcast attention to fp32 for long sequences
+        if kv_seq_len >= 2048 or bsz >= 64:
+            # for memory considerations, do not upcast attention to fp32
+            # for long sequences or large batches
             attn_weights = nn.functional.softmax(attn_weights, dim=-1)
         else:
             # upcast attention to fp32
@@ -1379,8 +1381,9 @@ def native_sdp(query, key, value, attention_mask,
                               f"but is {attention_mask.size()}")
         attn_weights = attn_weights + attention_mask
-    if kv_seq_len >= 2048:
-        # for memory considerations, do not upcast attention to fp32 for long sequences
+    if kv_seq_len >= 2048 or bsz >= 64:
+        # for memory considerations, do not upcast attention to fp32
+        # for long sequences or large batches
         attn_weights = nn.functional.softmax(attn_weights, dim=-1)
     else:
         # upcast attention to fp32
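
For reference, the gating logic applied identically in all three hunks can be summarized as a small standalone sketch. The helper name softmax_with_optional_upcast and its signature are illustrative only and not part of the patch; it assumes the usual attention-weights layout of (bsz, num_heads, q_len, kv_seq_len) and a PyTorch build where fp16 softmax is supported on the current device.

import torch
import torch.nn as nn

def softmax_with_optional_upcast(attn_weights, kv_seq_len, bsz, out_dtype):
    # Hypothetical helper (not part of the patch) mirroring the gating in this commit:
    # skip the fp32 upcast when the KV sequence is long or the batch is large,
    # trading a little softmax precision for lower peak memory.
    if kv_seq_len >= 2048 or bsz >= 64:
        # for memory considerations, do not upcast attention to fp32
        # for long sequences or large batches
        return nn.functional.softmax(attn_weights, dim=-1)
    # otherwise upcast to fp32 for the softmax, then cast back to the working dtype
    return nn.functional.softmax(attn_weights, dim=-1,
                                 dtype=torch.float32).to(out_dtype)

# Example: fp16 attention scores of shape (bsz, num_heads, q_len, kv_seq_len)
scores = torch.randn(64, 32, 1, 4096, dtype=torch.float16)
probs = softmax_with_optional_upcast(scores, kv_seq_len=4096, bsz=64,
                                     out_dtype=scores.dtype)
assert probs.dtype == torch.float16  # large batch: no upcast, stays in fp16

Upcasting materializes a float32 copy of the attention weights, roughly doubling that tensor's footprint compared with fp16, so the existing kv_seq_len >= 2048 guard and the new bsz >= 64 guard both skip the upcast when the weights tensor is already large.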