This commit is contained in:
Yishuo Wang 2024-09-26 17:15:16 +08:00 committed by GitHub
parent a266528719
commit 669ff1a97b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@@ -106,9 +106,9 @@ class AttnProcessor2_0:
         # the output of sdp = (batch, num_heads, seq_len, head_dim)
         # IPEX-LLM changes start
         if head_dim in [40, 80]:
-            import xe_test
-            hidden_states = xe_test.sdp_non_causal(query, key.contiguous(),
+            import xe_addons
+            hidden_states = xe_addons.sdp_non_causal(query, key.contiguous(),
                                                    value.contiguous(), attention_mask)
         else:
             scale = 1 / math.sqrt(head_dim)
             attn_weights = torch.matmul(query * scale, key.transpose(-1, -2))