parent cfdf8ad496
commit 021d77fd22
1 changed file with 6 additions and 9 deletions
@@ -995,9 +995,8 @@ def llama_attention_forward_4_36_quantized(
                )
            attn_weights = attn_weights + attention_mask

-        # upcast attention to fp32
-        attn_weights = nn.functional.softmax(attn_weights, dim=-1,
-                                             dtype=torch.float32).to(query_states.dtype)
+        # at inference time, for memory considerations, may not need to upcast attention to fp32
+        attn_weights = nn.functional.softmax(attn_weights, dim=-1)
        attn_output = torch.matmul(attn_weights, value_states)
        if use_cache:
            cache_kwargs = None
@@ -1036,9 +1035,8 @@ def llama_attention_forward_4_36_quantized(
                )
            attn_weights = attn_weights + attention_mask

-        # upcast attention to fp32
-        attn_weights = nn.functional.softmax(attn_weights,
-                                             dim=-1, dtype=torch.float32).to(query_states.dtype)
+        # at inference time, for memory considerations, may not need to upcast attention to fp32
+        attn_weights = nn.functional.softmax(attn_weights, dim=-1)

        if query_states.size(2) != 1 or query_states.device.type != 'xpu':
            attn_output = torch.matmul(attn_weights, value_states)
@@ -1324,9 +1322,8 @@ def native_sdp(query, key, value, attention_mask,
                          f"but is {attention_mask.size()}")
        attn_weights = attn_weights + attention_mask

-    # upcast attention to fp32
-    attn_weights = nn.functional.softmax(attn_weights, dim=-1,
-                                         dtype=torch.float32).to(value.dtype)
+    # at inference time, for memory considerations, may not need to upcast attention to fp32
+    attn_weights = nn.functional.softmax(attn_weights, dim=-1)
    attn_output = torch.matmul(attn_weights, value)
    return attn_output, attn_weights
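The hunks above drop the fp32 upcast around the attention softmax at inference time. Below is a minimal standalone sketch of the before/after behavior; the tensor shape and half-precision dtype are illustrative assumptions, not taken from the patched model code:

```python
import torch
import torch.nn as nn

# Hypothetical attention-score tensor: (batch, heads, q_len, kv_len) in fp16.
attn_weights = torch.randn(1, 32, 1, 1024, dtype=torch.float16)

# Before the patch: softmax computed in fp32, then cast back to the working
# dtype, which materializes an extra fp32 copy of the attention matrix.
upcast = nn.functional.softmax(attn_weights, dim=-1,
                               dtype=torch.float32).to(attn_weights.dtype)

# After the patch: softmax computed directly in the native (half-precision)
# dtype, avoiding the fp32 buffer at some cost in numerical accuracy.
native = nn.functional.softmax(attn_weights, dim=-1)

print(upcast.dtype, native.dtype)  # both torch.float16 after the final cast
print(torch.allclose(upcast.float(), native.float(), atol=1e-3))
```

Skipping the upcast avoids allocating an fp32 copy of the attention matrix, whose size grows with sequence length, while the final output dtype is unchanged.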