From c3fc8f4b90e92f1dbc7ed907cc8e8509e43a91a8 Mon Sep 17 00:00:00 2001
From: binbin Deng <108676127+plusbang@users.noreply.github.com>
Date: Fri, 12 Apr 2024 15:40:25 +0800
Subject: [PATCH] LLM: add bs limitation for llama softmax upcast to fp32
 (#10752)

---
 .../llm/src/ipex_llm/transformers/models/llama.py | 15 +++++++++------
 1 file changed, 9 insertions(+), 6 deletions(-)

diff --git a/python/llm/src/ipex_llm/transformers/models/llama.py b/python/llm/src/ipex_llm/transformers/models/llama.py
index 99bddf04..6649c180 100644
--- a/python/llm/src/ipex_llm/transformers/models/llama.py
+++ b/python/llm/src/ipex_llm/transformers/models/llama.py
@@ -1034,8 +1034,9 @@ def llama_attention_forward_4_36_quantized(
                 )
             attn_weights = attn_weights + attention_mask
 
-        if kv_seq_len >= 2048:
-            # for memory considerations, do not upcast attention to fp32 for long sequences
+        if kv_seq_len >= 2048 or bsz >= 64:
+            # for memory considerations, do not upcast attention to fp32
+            # for long sequences or large batches
             attn_weights = nn.functional.softmax(attn_weights, dim=-1)
         else:
             # upcast attention to fp32
@@ -1079,8 +1080,9 @@ def llama_attention_forward_4_36_quantized(
                 )
             attn_weights = attn_weights + attention_mask
 
-        if kv_seq_len >= 2048:
-            # for memory considerations, do not upcast attention to fp32 for long sequences
+        if kv_seq_len >= 2048 or bsz >= 64:
+            # for memory considerations, do not upcast attention to fp32
+            # for long sequences or large batches
             attn_weights = nn.functional.softmax(attn_weights, dim=-1)
         else:
             # upcast attention to fp32
@@ -1379,8 +1381,9 @@ def native_sdp(query, key, value, attention_mask,
                           f"but is {attention_mask.size()}")
         attn_weights = attn_weights + attention_mask
 
-    if kv_seq_len >= 2048:
-        # for memory considerations, do not upcast attention to fp32 for long sequences
+    if kv_seq_len >= 2048 or bsz >= 64:
+        # for memory considerations, do not upcast attention to fp32
+        # for long sequences or large batches
         attn_weights = nn.functional.softmax(attn_weights, dim=-1)
     else:
         # upcast attention to fp32
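
Note (not part of the patch): the change skips the fp32 upcast of the attention softmax not only for long sequences (kv_seq_len >= 2048) but also for large batches (bsz >= 64), because the temporary fp32 copy of the [bsz, num_heads, q_len, kv_seq_len] score tensor grows with batch size just as it does with sequence length. Below is a minimal standalone sketch of the pattern, assuming the thresholds from the patch; the helper name softmax_with_optional_upcast is hypothetical and not part of ipex-llm.

# Illustrative sketch only -- not ipex-llm code. Thresholds mirror the patch;
# the function name and tensor shapes are assumptions for demonstration.
import torch
import torch.nn as nn

def softmax_with_optional_upcast(attn_weights: torch.Tensor,
                                 kv_seq_len: int, bsz: int) -> torch.Tensor:
    if kv_seq_len >= 2048 or bsz >= 64:
        # Long sequence or large batch: an fp32 copy of the
        # [bsz, num_heads, q_len, kv_seq_len] scores would roughly double
        # peak memory, so keep softmax in the incoming (half-precision) dtype.
        return nn.functional.softmax(attn_weights, dim=-1)
    # Small case: upcast to fp32 for numerical stability, then cast back.
    return nn.functional.softmax(attn_weights, dim=-1,
                                 dtype=torch.float32).to(attn_weights.dtype)

The bsz >= 64 cutoff is the batch-size analogue of the existing kv_seq_len >= 2048 cutoff: both bound the size of the temporary fp32 score tensor rather than anything tied to the model weights.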