small change (#12439)
Parent: be132c4209
Commit: 8164aed802
3 changed files with 3 additions and 3 deletions

Summary: the stable-diffusion attention module moves from ipex_llm.transformers.models.sd15 to ipex_llm.transformers.models.sd, the "padding head_dim 40 to 64" comment moves inside the XPU branch it describes, and the head_dim allow-list of the non-causal SDP gate changes from [40, 64, 80] to [64, 80, 128].
@@ -1248,7 +1248,7 @@ def _optimize_post(model, lightweight_bmm=False):
     try:
         from diffusers import DiffusionPipeline
         if isinstance(model, DiffusionPipeline):
-            from ipex_llm.transformers.models.sd15 import AttnProcessor2_0
+            from ipex_llm.transformers.models.sd import AttnProcessor2_0
             model.unet.set_attn_processor(AttnProcessor2_0())
             return model
     except ModuleNotFoundError:
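For orientation, the patched branch installs IPEX-LLM's attention processor on the UNet of any diffusers DiffusionPipeline. A minimal sketch of the same call path, assuming the renamed ipex_llm.transformers.models.sd module is importable; the checkpoint id is illustrative only:

# Sketch mirroring what _optimize_post does for a DiffusionPipeline.
import torch
from diffusers import DiffusionPipeline
from ipex_llm.transformers.models.sd import AttnProcessor2_0

pipe = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",  # illustrative checkpoint id
    torch_dtype=torch.half,
)
# Replace every attention processor on the UNet with the SDPA-based
# variant, exactly as the patched code does.
pipe.unet.set_attn_processor(AttnProcessor2_0())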
@@ -106,8 +106,8 @@ class AttnProcessor2_0:

         # the output of sdp = (batch, num_heads, seq_len, head_dim)
         # IPEX-LLM changes start
-        # padding head_dim 40 to 64
         if query.device.type == "xpu" and query.dtype in [torch.half, torch.float]:
+            # padding head_dim 40 to 64
             query, key, value = padding_qkv_hd(query, key, value, 40, 64)

             if use_sdp_non_causal(head_dim, query.device, query.dtype):
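Why padding head_dim from 40 to 64 is safe: zeros appended to query and key contribute nothing to the q·kᵀ logits, and zeros appended to value only produce zero output channels, so slicing the SDP output back to 40 dims recovers the exact result, as long as the softmax scale of the original head_dim is kept. A minimal sketch under those assumptions (the real padding_qkv_hd in ipex_llm may differ; the explicit scale kwarg needs PyTorch >= 2.1):

# Hypothetical stand-in for padding_qkv_hd: zero-pad the head_dim axis.
import math
import torch
import torch.nn.functional as F

def padding_qkv_hd(query, key, value, from_hd=40, to_hd=64):
    if query.size(-1) == from_hd:
        pad = (0, to_hd - from_hd)  # pad only the last (head_dim) axis
        query = F.pad(query, pad)
        key = F.pad(key, pad)
        value = F.pad(value, pad)
    return query, key, value

q, k, v = (torch.randn(1, 8, 77, 40) for _ in range(3))
q64, k64, v64 = padding_qkv_hd(q, k, v)

# Keep the original head_dim's softmax scale, then slice off the padding.
out = F.scaled_dot_product_attention(q64, k64, v64, scale=1 / math.sqrt(40))[..., :40]
ref = F.scaled_dot_product_attention(q, k, v)  # default scale is 1/sqrt(40)
assert torch.allclose(out, ref, atol=1e-5)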
@@ -329,7 +329,7 @@ def use_sdp_causal(q_len, kv_len, head_dim, query_states, training):

 def use_sdp_non_causal(head_dim, device, dtype):
     return (
-        head_dim in [40, 64, 80]
+        head_dim in [64, 80, 128]
         and device.type == "xpu"  # GPU
         and dtype in [torch.float, torch.half]  # fp32/fp16
     )
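The gate change ties into the padding above: head_dim 40 is dropped from the direct allow-list (and 128 added), so 40-dim attention now reaches the fused path only after being padded to 64, while head_dim 128 qualifies directly. A quick check of the predicate as committed; constructing torch.device("xpu") only parses the string and should not require an actual XPU:

import torch

def use_sdp_non_causal(head_dim, device, dtype):
    return (
        head_dim in [64, 80, 128]
        and device.type == "xpu"  # GPU
        and dtype in [torch.float, torch.half]  # fp32/fp16
    )

xpu = torch.device("xpu")
assert not use_sdp_non_causal(40, xpu, torch.half)      # rejected as-is; padded first
assert use_sdp_non_causal(64, xpu, torch.half)          # accepted after padding
assert use_sdp_non_causal(128, xpu, torch.float)        # newly accepted
assert not use_sdp_non_causal(64, torch.device("cpu"), torch.half)  # XPU only
assert not use_sdp_non_causal(64, xpu, torch.bfloat16)  # bf16 not supported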