Remove softmax upcast fp32 in llama (#10481)

* update

* fix style
Kai Huang authored on 2024-03-20 18:17:34 +08:00, committed by GitHub
parent cfdf8ad496
commit 021d77fd22

@@ -995,9 +995,8 @@ def llama_attention_forward_4_36_quantized(
                 )
             attn_weights = attn_weights + attention_mask
-            # upcast attention to fp32
-            attn_weights = nn.functional.softmax(attn_weights, dim=-1,
-                                                 dtype=torch.float32).to(query_states.dtype)
+            # at inference time, for memory considerations, may not need to upcast attention to fp32
+            attn_weights = nn.functional.softmax(attn_weights, dim=-1)
             attn_output = torch.matmul(attn_weights, value_states)
         if use_cache:
             cache_kwargs = None
@@ -1036,9 +1035,8 @@ def llama_attention_forward_4_36_quantized(
                 )
             attn_weights = attn_weights + attention_mask
-            # upcast attention to fp32
-            attn_weights = nn.functional.softmax(attn_weights,
-                                                 dim=-1, dtype=torch.float32).to(query_states.dtype)
+            # at inference time, for memory considerations, may not need to upcast attention to fp32
+            attn_weights = nn.functional.softmax(attn_weights, dim=-1)
             if query_states.size(2) != 1 or query_states.device.type != 'xpu':
                 attn_output = torch.matmul(attn_weights, value_states)
@@ -1324,9 +1322,8 @@ def native_sdp(query, key, value, attention_mask,
                               f"but is {attention_mask.size()}")
         attn_weights = attn_weights + attention_mask
-    # upcast attention to fp32
-    attn_weights = nn.functional.softmax(attn_weights, dim=-1,
-                                         dtype=torch.float32).to(value.dtype)
+    # at inference time, for memory considerations, may not need to upcast attention to fp32
+    attn_weights = nn.functional.softmax(attn_weights, dim=-1)
     attn_output = torch.matmul(attn_weights, value)
     return attn_output, attn_weights