small change (#12439)
parent be132c4209
commit 8164aed802
3 changed files with 3 additions and 3 deletions
@@ -1248,7 +1248,7 @@ def _optimize_post(model, lightweight_bmm=False):
     try:
         from diffusers import DiffusionPipeline
         if isinstance(model, DiffusionPipeline):
-            from ipex_llm.transformers.models.sd15 import AttnProcessor2_0
+            from ipex_llm.transformers.models.sd import AttnProcessor2_0
             model.unet.set_attn_processor(AttnProcessor2_0())
             return model
     except ModuleNotFoundError:
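The hunk above switches the Stable Diffusion attention-processor import in _optimize_post from the old models.sd15 module to models.sd. As a hedged usage sketch (not part of this commit), this is how the DiffusionPipeline branch is normally reached; it mirrors the pattern in ipex-llm's Stable Diffusion examples, and the model id and low_bit=None are illustrative assumptions:

import torch
from diffusers import DiffusionPipeline
from ipex_llm import optimize_model

# Load any Stable Diffusion pipeline; the model id is illustrative.
pipe = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    torch_dtype=torch.half,
)
# optimize_model() is expected to reach _optimize_post(), whose DiffusionPipeline
# branch (shown above) installs AttnProcessor2_0 from models.sd on the UNet.
pipe = optimize_model(pipe, low_bit=None)
pipe = pipe.to("xpu")
image = pipe("an astronaut riding a horse").images[0]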
@@ -106,8 +106,8 @@ class AttnProcessor2_0:

         # the output of sdp = (batch, num_heads, seq_len, head_dim)
         # IPEX-LLM changes start
-        # padding head_dim 40 to 64
         if query.device.type == "xpu" and query.dtype in [torch.half, torch.float]:
+            # padding head_dim 40 to 64
             query, key, value = padding_qkv_hd(query, key, value, 40, 64)

             if use_sdp_non_causal(head_dim, query.device, query.dtype):
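This hunk only moves the "# padding head_dim 40 to 64" comment inside the XPU branch, next to the padding_qkv_hd call it describes. For readers unfamiliar with that helper, here is a rough, illustrative sketch of what the padding step amounts to; the real padding_qkv_hd helper in ipex_llm may differ in detail:

import torch
import torch.nn.functional as F

def pad_qkv_head_dim(query, key, value, from_hd=40, to_hd=64):
    # Zero-pad the trailing head_dim axis (e.g. 40 -> 64) so the fused SDP
    # kernel sees a supported head size; shapes are (batch, heads, seq, head_dim).
    if query.size(-1) == from_hd:
        pad = to_hd - from_hd
        query = F.pad(query, (0, pad))
        key = F.pad(key, (0, pad))
        value = F.pad(value, (0, pad))
    return query, key, value

q = torch.randn(1, 8, 4096, 40)
q, k, v = pad_qkv_head_dim(q, q.clone(), q.clone())
print(q.shape)  # torch.Size([1, 8, 4096, 64])

Presumably the attention output is sliced back to the original head_dim after SDP, but that code is outside this hunk.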
@@ -329,7 +329,7 @@ def use_sdp_causal(q_len, kv_len, head_dim, query_states, training):

 def use_sdp_non_causal(head_dim, device, dtype):
     return (
-        head_dim in [40, 64, 80]
+        head_dim in [64, 80, 128]
         and device.type == "xpu"  # GPU
         and dtype in [torch.float, torch.half]  # fp32/fp16
     )
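The gate now accepts head_dim 64, 80 or 128 instead of 40, 64, 80; together with the padding in the previous hunk, 40-dim heads (Stable Diffusion 1.5) are expected to reach this check only after being padded to 64, while 128 is newly allowed. A small self-contained check of the updated gate (it mirrors the function above; torch.device("xpu") only builds a device descriptor, so no XPU is needed to run it):

import torch

def use_sdp_non_causal(head_dim, device, dtype):
    # Same whitelist as the updated function above.
    return (
        head_dim in [64, 80, 128]
        and device.type == "xpu"
        and dtype in [torch.float, torch.half]
    )

xpu = torch.device("xpu")
print(use_sdp_non_causal(40, xpu, torch.half))    # False: 40 is handled via padding to 64
print(use_sdp_non_causal(64, xpu, torch.half))    # True
print(use_sdp_non_causal(128, xpu, torch.float))  # True
print(use_sdp_non_causal(64, torch.device("cpu"), torch.half))  # False: XPU only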