parent
cfdf8ad496
commit
021d77fd22
1 changed file with 6 additions and 9 deletions
@@ -995,9 +995,8 @@ def llama_attention_forward_4_36_quantized(
             )
         attn_weights = attn_weights + attention_mask
 
-    # upcast attention to fp32
-    attn_weights = nn.functional.softmax(attn_weights, dim=-1,
-                                         dtype=torch.float32).to(query_states.dtype)
+    # at inference time, for memory considerations, may not need to upcast attention to fp32
+    attn_weights = nn.functional.softmax(attn_weights, dim=-1)
     attn_output = torch.matmul(attn_weights, value_states)
     if use_cache:
         cache_kwargs = None
@@ -1036,9 +1035,8 @@ def llama_attention_forward_4_36_quantized(
             )
         attn_weights = attn_weights + attention_mask
 
-    # upcast attention to fp32
-    attn_weights = nn.functional.softmax(attn_weights,
-                                         dim=-1, dtype=torch.float32).to(query_states.dtype)
+    # at inference time, for memory considerations, may not need to upcast attention to fp32
+    attn_weights = nn.functional.softmax(attn_weights, dim=-1)
 
     if query_states.size(2) != 1 or query_states.device.type != 'xpu':
         attn_output = torch.matmul(attn_weights, value_states)
@@ -1324,9 +1322,8 @@ def native_sdp(query, key, value, attention_mask,
                               f"but is {attention_mask.size()}")
         attn_weights = attn_weights + attention_mask
 
-    # upcast attention to fp32
-    attn_weights = nn.functional.softmax(attn_weights, dim=-1,
-                                         dtype=torch.float32).to(value.dtype)
+    # at inference time, for memory considerations, may not need to upcast attention to fp32
+    attn_weights = nn.functional.softmax(attn_weights, dim=-1)
     attn_output = torch.matmul(attn_weights, value)
     return attn_output, attn_weights
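
Below is a minimal, self-contained sketch (not part of the commit) contrasting the removed fp32-upcast softmax with the native-dtype softmax that replaces it in all three hunks. The tensor shape and the use of bfloat16 are illustrative assumptions, standing in for the half-precision attention weights the forward pass would produce on device.

import torch
import torch.nn as nn

# Illustrative attention-weight tensor; shape and dtype are assumptions,
# not values taken from the patched functions.
attn_weights = torch.randn(1, 32, 1, 1024, dtype=torch.bfloat16)

# Old behavior (removed by this commit): upcast to fp32 for the softmax,
# then cast back to the original dtype. Numerically more robust, but it
# materializes a temporary fp32 copy of the attention weights.
upcast = nn.functional.softmax(attn_weights, dim=-1,
                               dtype=torch.float32).to(attn_weights.dtype)

# New behavior: softmax directly in the native low-precision dtype,
# avoiding the fp32 temporary at inference time.
native = nn.functional.softmax(attn_weights, dim=-1)

print(upcast.dtype, native.dtype)                      # both low precision
print(torch.allclose(upcast, native, atol=1e-2))       # results stay close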