diff --git a/python/llm/src/bigdl/llm/transformers/models/llama.py b/python/llm/src/bigdl/llm/transformers/models/llama.py
index cf8c9896..448a0cf3 100644
--- a/python/llm/src/bigdl/llm/transformers/models/llama.py
+++ b/python/llm/src/bigdl/llm/transformers/models/llama.py
@@ -995,9 +995,8 @@ def llama_attention_forward_4_36_quantized(
             )
             attn_weights = attn_weights + attention_mask

-        # upcast attention to fp32
-        attn_weights = nn.functional.softmax(attn_weights, dim=-1,
-                                             dtype=torch.float32).to(query_states.dtype)
+        # at inference time, for memory considerations, may not need to upcast attention to fp32
+        attn_weights = nn.functional.softmax(attn_weights, dim=-1)
         attn_output = torch.matmul(attn_weights, value_states)
         if use_cache:
             cache_kwargs = None
@@ -1036,9 +1035,8 @@ def llama_attention_forward_4_36_quantized(
                 )
                 attn_weights = attn_weights + attention_mask

-            # upcast attention to fp32
-            attn_weights = nn.functional.softmax(attn_weights,
-                                                 dim=-1, dtype=torch.float32).to(query_states.dtype)
+            # at inference time, for memory considerations, may not need to upcast attention to fp32
+            attn_weights = nn.functional.softmax(attn_weights, dim=-1)

             if query_states.size(2) != 1 or query_states.device.type != 'xpu':
                 attn_output = torch.matmul(attn_weights, value_states)
@@ -1324,9 +1322,8 @@ def native_sdp(query, key, value, attention_mask,
                           f"but is {attention_mask.size()}")
         attn_weights = attn_weights + attention_mask

-    # upcast attention to fp32
-    attn_weights = nn.functional.softmax(attn_weights, dim=-1,
-                                         dtype=torch.float32).to(value.dtype)
+    # at inference time, for memory considerations, may not need to upcast attention to fp32
+    attn_weights = nn.functional.softmax(attn_weights, dim=-1)
     attn_output = torch.matmul(attn_weights, value)
     return attn_output, attn_weights
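
For context, a minimal standalone sketch (not part of the patch; tensor shapes and variable names below are illustrative assumptions) of what the change does: dropping the dtype=torch.float32 argument keeps the softmax in the attention weights' original half-precision dtype, so no temporary fp32 copy of the [bsz, num_heads, q_len, kv_seq_len] tensor is materialized during inference.

# Sketch only; assumes PyTorch with fp16 tensors, shapes are made up for illustration.
import torch
import torch.nn as nn

bsz, num_heads, q_len, kv_seq_len = 1, 32, 1, 4096
attn_weights = torch.randn(bsz, num_heads, q_len, kv_seq_len, dtype=torch.float16)

# Path removed by this patch: upcast to fp32 for numerical stability, then cast
# back, which allocates an extra fp32 buffer the size of attn_weights.
upcast = nn.functional.softmax(attn_weights, dim=-1,
                               dtype=torch.float32).to(attn_weights.dtype)

# Path added by this patch: softmax in the original dtype, no fp32 temporary.
in_dtype = nn.functional.softmax(attn_weights, dim=-1)

# For typical decoding shapes the two results stay close.
print(torch.allclose(upcast.float(), in_dtype.float(), atol=1e-3))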