diff --git a/python/llm/src/ipex_llm/transformers/convert.py b/python/llm/src/ipex_llm/transformers/convert.py
index 34815b40..36aa98c1 100644
--- a/python/llm/src/ipex_llm/transformers/convert.py
+++ b/python/llm/src/ipex_llm/transformers/convert.py
@@ -1248,7 +1248,7 @@ def _optimize_post(model, lightweight_bmm=False):
     try:
         from diffusers import DiffusionPipeline
         if isinstance(model, DiffusionPipeline):
-            from ipex_llm.transformers.models.sd15 import AttnProcessor2_0
+            from ipex_llm.transformers.models.sd import AttnProcessor2_0
             model.unet.set_attn_processor(AttnProcessor2_0())
         return model
     except ModuleNotFoundError:
diff --git a/python/llm/src/ipex_llm/transformers/models/sd15.py b/python/llm/src/ipex_llm/transformers/models/sd.py
similarity index 99%
rename from python/llm/src/ipex_llm/transformers/models/sd15.py
rename to python/llm/src/ipex_llm/transformers/models/sd.py
index 769b2dd2..7bd4cf82 100644
--- a/python/llm/src/ipex_llm/transformers/models/sd15.py
+++ b/python/llm/src/ipex_llm/transformers/models/sd.py
@@ -106,8 +106,8 @@ class AttnProcessor2_0:
         # the output of sdp = (batch, num_heads, seq_len, head_dim)
 
         # IPEX-LLM changes start
-        # padding head_dim 40 to 64
         if query.device.type == "xpu" and query.dtype in [torch.half, torch.float]:
+            # padding head_dim 40 to 64
            query, key, value = padding_qkv_hd(query, key, value, 40, 64)
 
             if use_sdp_non_causal(head_dim, query.device, query.dtype):
diff --git a/python/llm/src/ipex_llm/transformers/models/utils.py b/python/llm/src/ipex_llm/transformers/models/utils.py
index 3d650bf3..463ff4bc 100644
--- a/python/llm/src/ipex_llm/transformers/models/utils.py
+++ b/python/llm/src/ipex_llm/transformers/models/utils.py
@@ -329,7 +329,7 @@ def use_sdp_causal(q_len, kv_len, head_dim, query_states, training):
 
 def use_sdp_non_causal(head_dim, device, dtype):
     return (
-        head_dim in [40, 64, 80]
+        head_dim in [64, 80, 128]
         and device.type == "xpu"  # GPU
         and dtype in [torch.float, torch.half]  # fp32/fp16
     )
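
For context on why padding head_dim 40 to 64 is safe, below is a minimal sketch of the idea behind the `padding_qkv_hd(query, key, value, 40, 64)` call. The helper `pad_qkv_head_dim` is hypothetical (the diff does not show `padding_qkv_hd`'s internals), the SDP call is stock PyTorch rather than IPEX-LLM's XPU kernel, and the explicit `scale` argument assumes PyTorch >= 2.1:

```python
import math

import torch
import torch.nn.functional as F


def pad_qkv_head_dim(query, key, value, old_hd=40, new_hd=64):
    # Zero-pad the last (head_dim) dimension from old_hd to new_hd.
    # Zeros in q and k leave every q . k dot product unchanged, and zeros
    # in v only append zero columns to the attention output, so slicing
    # the output back to old_hd recovers the exact result.
    pad = (0, new_hd - old_hd)  # F.pad pads the last dim as (left, right)
    return F.pad(query, pad), F.pad(key, pad), F.pad(value, pad)


# Shapes follow AttnProcessor2_0: (batch, num_heads, seq_len, head_dim)
q, k, v = (torch.randn(1, 8, 77, 40) for _ in range(3))

ref = F.scaled_dot_product_attention(q, k, v)  # head_dim 40

# Keep the original 1/sqrt(40) softmax scaling; otherwise SDP would
# infer 1/sqrt(64) from the padded shape and the results would diverge.
scale = 1.0 / math.sqrt(q.size(-1))
qp, kp, vp = pad_qkv_head_dim(q, k, v)  # head_dim 64
out = F.scaled_dot_product_attention(qp, kp, vp, scale=scale)[..., :40]

assert torch.allclose(ref, out, atol=1e-5)
```

This also suggests a reading of the `use_sdp_non_causal` change: since padding now happens before the check, SD 1.5's native head_dim 40 presumably reaches the check as 64, so 40 can drop out of the allow-list, while 128 appears to be added for larger-head models now that the module covers more than SD 1.5 (hence the sd15.py to sd.py rename).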