Remove softmax upcast fp32 in llama (#10481)

* update

* fix style
Kai Huang 2024-03-20 18:17:34 +08:00 committed by GitHub
parent cfdf8ad496
commit 021d77fd22

@@ -995,9 +995,8 @@ def llama_attention_forward_4_36_quantized(
             )
         attn_weights = attn_weights + attention_mask
-    # upcast attention to fp32
-    attn_weights = nn.functional.softmax(attn_weights, dim=-1,
-                                         dtype=torch.float32).to(query_states.dtype)
+    # at inference time, for memory considerations, may not need to upcast attention to fp32
+    attn_weights = nn.functional.softmax(attn_weights, dim=-1)
     attn_output = torch.matmul(attn_weights, value_states)
     if use_cache:
         cache_kwargs = None
@@ -1036,9 +1035,8 @@ def llama_attention_forward_4_36_quantized(
             )
         attn_weights = attn_weights + attention_mask
-    # upcast attention to fp32
-    attn_weights = nn.functional.softmax(attn_weights,
-                                         dim=-1, dtype=torch.float32).to(query_states.dtype)
+    # at inference time, for memory considerations, may not need to upcast attention to fp32
+    attn_weights = nn.functional.softmax(attn_weights, dim=-1)
     if query_states.size(2) != 1 or query_states.device.type != 'xpu':
         attn_output = torch.matmul(attn_weights, value_states)
@@ -1324,9 +1322,8 @@ def native_sdp(query, key, value, attention_mask,
                               f"but is {attention_mask.size()}")
         attn_weights = attn_weights + attention_mask
-    # upcast attention to fp32
-    attn_weights = nn.functional.softmax(attn_weights, dim=-1,
-                                         dtype=torch.float32).to(value.dtype)
+    # at inference time, for memory considerations, may not need to upcast attention to fp32
+    attn_weights = nn.functional.softmax(attn_weights, dim=-1)
     attn_output = torch.matmul(attn_weights, value)
     return attn_output, attn_weights
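
For illustration only (not part of this patch), a minimal sketch of what the removed upcast changes at runtime, assuming a half-precision attn_weights tensor with made-up shapes:

import torch
import torch.nn as nn

# Hypothetical attention scores: (batch, heads, q_len, kv_len); shapes are illustrative only.
attn_weights = torch.randn(1, 32, 1, 1024, dtype=torch.float16)

# Before this patch: softmax is computed in fp32 and cast back to the query dtype,
# which allocates an extra fp32 intermediate the size of attn_weights.
out_upcast = nn.functional.softmax(attn_weights, dim=-1,
                                   dtype=torch.float32).to(attn_weights.dtype)

# After this patch: softmax runs in the input dtype (fp16 here), skipping the
# fp32 intermediate and saving memory at inference time.
out_native = nn.functional.softmax(attn_weights, dim=-1)

print(out_upcast.dtype, out_native.dtype)  # torch.float16 torch.float16

Both variants return the same dtype and shape; the trade-off is a small numerical difference in the softmax in exchange for not materializing the fp32 buffer during inference.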