fix and optimize sd (#12436)
This commit is contained in:
parent
f41405368a
commit
be132c4209
3 changed files with 25 additions and 14 deletions
|
|
@ -229,7 +229,7 @@ def optimize_model(model, low_bit='sym_int4', optimize_llm=True, modules_to_not_
|
|||
f"Unknown load_in_low_bit value: {low_bit}, expected:"
|
||||
f" sym_int4, asym_int4, sym_int5, asym_int5 or sym_int8.")
|
||||
invalidInputError(isinstance(model, torch.nn.Module) or
|
||||
model.__class__.__name__ == "StableDiffusionPipeline",
|
||||
"StableDiffusion" in model.__class__.__name__,
|
||||
"model should be an instance of "
|
||||
f"`torch.nn.Module`, but got {type(model)} at last.")
|
||||
# To adapt vLLM models
|
||||
|
|
|
|||
|
|
@ -1246,8 +1246,8 @@ def _optimize_ipex(model, qtype=ggml_tensor_qtype["bf16"]):
|
|||
|
||||
def _optimize_post(model, lightweight_bmm=False):
|
||||
try:
|
||||
from diffusers import StableDiffusionPipeline
|
||||
if isinstance(model, StableDiffusionPipeline):
|
||||
from diffusers import DiffusionPipeline
|
||||
if isinstance(model, DiffusionPipeline):
|
||||
from ipex_llm.transformers.models.sd15 import AttnProcessor2_0
|
||||
model.unet.set_attn_processor(AttnProcessor2_0())
|
||||
return model
|
||||
|
|
|
|||
|
|
@ -36,7 +36,8 @@ import math
|
|||
import torch
|
||||
from typing import Optional
|
||||
|
||||
from ipex_llm.transformers.models.common import attention_softmax
|
||||
from ipex_llm.transformers.models.common import padding_qkv_hd, attention_softmax
|
||||
from ipex_llm.transformers.models.utils import use_sdp_non_causal
|
||||
from diffusers.models.attention_processor import Attention
|
||||
|
||||
|
||||
|
|
@ -105,7 +106,11 @@ class AttnProcessor2_0:
|
|||
|
||||
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
||||
# IPEX-LLM changes start
|
||||
if head_dim in [40, 80]:
|
||||
# padding head_dim 40 to 64
|
||||
if query.device.type == "xpu" and query.dtype in [torch.half, torch.float]:
|
||||
query, key, value = padding_qkv_hd(query, key, value, 40, 64)
|
||||
|
||||
if use_sdp_non_causal(head_dim, query.device, query.dtype):
|
||||
import xe_addons
|
||||
hidden_states = xe_addons.sdp_non_causal(query, key.contiguous(),
|
||||
value.contiguous(), attention_mask)
|
||||
|
|
@ -116,6 +121,12 @@ class AttnProcessor2_0:
|
|||
attn_weights = attn_weights + attention_mask
|
||||
attn_weights = attention_softmax(attn_weights)
|
||||
hidden_states = torch.matmul(attn_weights, value)
|
||||
|
||||
hidden_states = hidden_states[:, :, :, :head_dim]
|
||||
else:
|
||||
hidden_states = torch.nn.functional.scaled_dot_product_attention(
|
||||
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
||||
)
|
||||
# IPEX-LLM changes end
|
||||
|
||||
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1,
|
||||
|
|
|
|||
Loading…
Reference in a new issue