fix sd1.5 (#12129)
This commit is contained in:
parent
a266528719
commit
669ff1a97b
1 changed files with 3 additions and 3 deletions
|
|
@ -106,8 +106,8 @@ class AttnProcessor2_0:
|
|||
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
||||
# IPEX-LLM changes start
|
||||
if head_dim in [40, 80]:
|
||||
import xe_test
|
||||
hidden_states = xe_test.sdp_non_causal(query, key.contiguous(),
|
||||
import xe_addons
|
||||
hidden_states = xe_addons.sdp_non_causal(query, key.contiguous(),
|
||||
value.contiguous(), attention_mask)
|
||||
else:
|
||||
scale = 1 / math.sqrt(head_dim)
|
||||
|
|
|
|||
Loading…
Reference in a new issue