From 021d77fd22f14baa99aa51f5c318006b50fc2e9c Mon Sep 17 00:00:00 2001
From: Kai Huang
Date: Wed, 20 Mar 2024 18:17:34 +0800
Subject: [PATCH] Remove softmax upcast fp32 in llama (#10481)

* update

* fix style
---
 .../src/bigdl/llm/transformers/models/llama.py    | 15 ++++++---------
 1 file changed, 6 insertions(+), 9 deletions(-)

diff --git a/python/llm/src/bigdl/llm/transformers/models/llama.py b/python/llm/src/bigdl/llm/transformers/models/llama.py
index cf8c9896..448a0cf3 100644
--- a/python/llm/src/bigdl/llm/transformers/models/llama.py
+++ b/python/llm/src/bigdl/llm/transformers/models/llama.py
@@ -995,9 +995,8 @@ def llama_attention_forward_4_36_quantized(
             )
             attn_weights = attn_weights + attention_mask
-        # upcast attention to fp32
-        attn_weights = nn.functional.softmax(attn_weights, dim=-1,
-                                             dtype=torch.float32).to(query_states.dtype)
+        # at inference time, for memory considerations, may not need to upcast attention to fp32
+        attn_weights = nn.functional.softmax(attn_weights, dim=-1)
         attn_output = torch.matmul(attn_weights, value_states)
     if use_cache:
         cache_kwargs = None
@@ -1036,9 +1035,8 @@ def llama_attention_forward_4_36_quantized(
             )
             attn_weights = attn_weights + attention_mask
-        # upcast attention to fp32
-        attn_weights = nn.functional.softmax(attn_weights,
-                                             dim=-1, dtype=torch.float32).to(query_states.dtype)
+        # at inference time, for memory considerations, may not need to upcast attention to fp32
+        attn_weights = nn.functional.softmax(attn_weights, dim=-1)
         if query_states.size(2) != 1 or query_states.device.type != 'xpu':
             attn_output = torch.matmul(attn_weights, value_states)
@@ -1324,9 +1322,8 @@ def native_sdp(query, key, value, attention_mask,
                           f"but is {attention_mask.size()}")
         attn_weights = attn_weights + attention_mask
-    # upcast attention to fp32
-    attn_weights = nn.functional.softmax(attn_weights, dim=-1,
-                                         dtype=torch.float32).to(value.dtype)
+    # at inference time, for memory considerations, may not need to upcast attention to fp32
+    attn_weights = nn.functional.softmax(attn_weights, dim=-1)
     attn_output = torch.matmul(attn_weights, value)
     return attn_output, attn_weights
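
Note (not part of the patch): a minimal standalone sketch of what the change amounts to. Dropping dtype=torch.float32 keeps both the softmax computation and the resulting attention-weight tensor in the reduced-precision input dtype instead of materializing an fp32 intermediate, which is where the memory saving at inference comes from. The tensor shape and bfloat16 dtype below are arbitrary illustrations, not taken from the repository.

# Standalone sketch of the before/after softmax behavior; shape and dtype are illustrative only.
import torch
import torch.nn as nn

# Fake attention scores in a reduced-precision dtype, shaped [batch, heads, q_len, kv_len].
attn_weights = torch.randn(1, 32, 1, 1024, dtype=torch.bfloat16)

# Previous behavior: softmax computed in fp32 (allocates an fp32 intermediate), then cast back.
upcast = nn.functional.softmax(attn_weights, dim=-1,
                               dtype=torch.float32).to(attn_weights.dtype)

# Patched behavior: softmax stays in the input dtype, avoiding the fp32 intermediate.
no_upcast = nn.functional.softmax(attn_weights, dim=-1)

print(upcast.dtype, no_upcast.dtype)                 # torch.bfloat16 torch.bfloat16
print(torch.allclose(upcast, no_upcast, atol=1e-2))  # True: results stay numerically close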