fix sd1.5 (#12129)
parent a266528719
commit 669ff1a97b

1 changed file with 3 additions and 3 deletions
@@ -106,9 +106,9 @@ class AttnProcessor2_0:
         # the output of sdp = (batch, num_heads, seq_len, head_dim)
         # IPEX-LLM changes start
         if head_dim in [40, 80]:
-            import xe_test
-            hidden_states = xe_test.sdp_non_causal(query, key.contiguous(),
-                                                   value.contiguous(), attention_mask)
+            import xe_addons
+            hidden_states = xe_addons.sdp_non_causal(query, key.contiguous(),
+                                                     value.contiguous(), attention_mask)
         else:
             scale = 1 / math.sqrt(head_dim)
             attn_weights = torch.matmul(query * scale, key.transpose(-1, -2))
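For context, the hunk swaps the temporary xe_test module for the released xe_addons module in IPEX-LLM's patched SD 1.5 attention processor: for head dimensions 40 and 80 it dispatches to the fused non-causal SDP kernel, and otherwise falls back to manual attention. Below is a minimal sketch of that dispatch; the diff only shows the first two lines of the fallback branch, so the softmax/matmul completion is an assumption (the standard scaled-dot-product continuation), and sdp_dispatch is a hypothetical wrapper name, not a function from the patch.

    import math
    import torch

    def sdp_dispatch(query, key, value, attention_mask):
        # query/key/value: (batch, num_heads, seq_len, head_dim)
        head_dim = query.size(-1)
        if head_dim in [40, 80]:
            # Fused SDP kernel from IPEX-LLM's native extension,
            # as called in the patched AttnProcessor2_0.
            import xe_addons
            hidden_states = xe_addons.sdp_non_causal(query, key.contiguous(),
                                                     value.contiguous(), attention_mask)
        else:
            # Manual fallback; the first two lines match the hunk, the rest
            # is the assumed standard completion (not shown in the diff).
            scale = 1 / math.sqrt(head_dim)
            attn_weights = torch.matmul(query * scale, key.transpose(-1, -2))
            if attention_mask is not None:
                attn_weights = attn_weights + attention_mask
            attn_weights = torch.softmax(attn_weights, dim=-1)
            hidden_states = torch.matmul(attn_weights, value)
        return hidden_states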