diff --git a/python/llm/src/ipex_llm/transformers/models/sd15.py b/python/llm/src/ipex_llm/transformers/models/sd15.py
index 0d8f3532..ab999d40 100644
--- a/python/llm/src/ipex_llm/transformers/models/sd15.py
+++ b/python/llm/src/ipex_llm/transformers/models/sd15.py
@@ -106,9 +106,9 @@ class AttnProcessor2_0:
         # the output of sdp = (batch, num_heads, seq_len, head_dim)
         # IPEX-LLM changes start
         if head_dim in [40, 80]:
-            import xe_test
-            hidden_states = xe_test.sdp_non_causal(query, key.contiguous(),
-                                                   value.contiguous(), attention_mask)
+            import xe_addons
+            hidden_states = xe_addons.sdp_non_causal(query, key.contiguous(),
+                                                     value.contiguous(), attention_mask)
         else:
             scale = 1 / math.sqrt(head_dim)
             attn_weights = torch.matmul(query * scale, key.transpose(-1, -2))