small change (#12439)

This commit is contained in:
Yishuo Wang 2024-11-25 14:35:49 +08:00 committed by GitHub
parent be132c4209
commit 8164aed802
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 3 additions and 3 deletions

View file

@ -1248,7 +1248,7 @@ def _optimize_post(model, lightweight_bmm=False):
try:
from diffusers import DiffusionPipeline
if isinstance(model, DiffusionPipeline):
from ipex_llm.transformers.models.sd15 import AttnProcessor2_0
from ipex_llm.transformers.models.sd import AttnProcessor2_0
model.unet.set_attn_processor(AttnProcessor2_0())
return model
except ModuleNotFoundError:

View file

@ -106,8 +106,8 @@ class AttnProcessor2_0:
# the output of sdp = (batch, num_heads, seq_len, head_dim)
# IPEX-LLM changes start
# padding head_dim 40 to 64
if query.device.type == "xpu" and query.dtype in [torch.half, torch.float]:
# padding head_dim 40 to 64
query, key, value = padding_qkv_hd(query, key, value, 40, 64)
if use_sdp_non_causal(head_dim, query.device, query.dtype):

View file

@ -329,7 +329,7 @@ def use_sdp_causal(q_len, kv_len, head_dim, query_states, training):
def use_sdp_non_causal(head_dim, device, dtype):
return (
head_dim in [40, 64, 80]
head_dim in [64, 80, 128]
and device.type == "xpu" # GPU
and dtype in [torch.float, torch.half] # fp32/fp16
)