diff --git a/python/llm/src/ipex_llm/transformers/models/llama.py b/python/llm/src/ipex_llm/transformers/models/llama.py
index 99bddf04..6649c180 100644
--- a/python/llm/src/ipex_llm/transformers/models/llama.py
+++ b/python/llm/src/ipex_llm/transformers/models/llama.py
@@ -1034,8 +1034,9 @@ def llama_attention_forward_4_36_quantized(
                 )
             attn_weights = attn_weights + attention_mask

-        if kv_seq_len >= 2048:
-            # for memory considerations, do not upcast attention to fp32 for long sequences
+        if kv_seq_len >= 2048 or bsz >= 64:
+            # for memory considerations, do not upcast attention to fp32
+            # for long sequences or large batches
             attn_weights = nn.functional.softmax(attn_weights, dim=-1)
         else:
             # upcast attention to fp32
@@ -1079,8 +1080,9 @@ def llama_attention_forward_4_36_quantized(
                 )
             attn_weights = attn_weights + attention_mask

-        if kv_seq_len >= 2048:
-            # for memory considerations, do not upcast attention to fp32 for long sequences
+        if kv_seq_len >= 2048 or bsz >= 64:
+            # for memory considerations, do not upcast attention to fp32
+            # for long sequences or large batches
             attn_weights = nn.functional.softmax(attn_weights, dim=-1)
         else:
             # upcast attention to fp32
@@ -1379,8 +1381,9 @@ def native_sdp(query, key, value, attention_mask,
                                 f"but is {attention_mask.size()}")
         attn_weights = attn_weights + attention_mask

-    if kv_seq_len >= 2048:
-        # for memory considerations, do not upcast attention to fp32 for long sequences
+    if kv_seq_len >= 2048 or bsz >= 64:
+        # for memory considerations, do not upcast attention to fp32
+        # for long sequences or large batches
        attn_weights = nn.functional.softmax(attn_weights, dim=-1)
     else:
         # upcast attention to fp32
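
For context, a minimal sketch of the conditional-upcast pattern these hunks apply: softmax stays in the activations' native dtype when the score matrix is large (long KV sequence or big batch), and is only upcast to fp32 for numerical stability on small inputs. The helper name `softmax_with_conditional_upcast` is illustrative; the actual upcast path lives in the unchanged `else` branch of the diff, assumed here to follow the standard Llama `softmax(..., dtype=torch.float32)` pattern.

```python
import torch
import torch.nn as nn


def softmax_with_conditional_upcast(attn_weights: torch.Tensor,
                                    kv_seq_len: int, bsz: int) -> torch.Tensor:
    # The fp32 copy of the [bsz, n_heads, q_len, kv_seq_len] score tensor can
    # dominate peak memory, so skip the upcast when the tensor is large.
    if kv_seq_len >= 2048 or bsz >= 64:
        # keep the original dtype (e.g. fp16/bf16) for long sequences or
        # large batches to bound memory use
        return nn.functional.softmax(attn_weights, dim=-1)
    # otherwise upcast to fp32 for numerical stability, then cast back
    return nn.functional.softmax(attn_weights, dim=-1,
                                 dtype=torch.float32).to(attn_weights.dtype)
```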