fix and optimize sd (#12436)

commit be132c4209 (parent f41405368a)

3 changed files with 25 additions and 14 deletions
@@ -229,7 +229,7 @@ def optimize_model(model, low_bit='sym_int4', optimize_llm=True, modules_to_not_
                           f"Unknown load_in_low_bit value: {low_bit}, expected:"
                           f" sym_int4, asym_int4, sym_int5, asym_int5 or sym_int8.")
     invalidInputError(isinstance(model, torch.nn.Module) or
-                      model.__class__.__name__ == "StableDiffusionPipeline",
+                      "StableDiffusion" in model.__class__.__name__,
                       "model should be an instance of "
                       f"`torch.nn.Module`, but got {type(model)} at last.")
     # To adapt vLLM models
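Note: the relaxed class-name check is what lets other Stable Diffusion pipeline variants through. diffusers ships several pipeline classes whose names contain "StableDiffusion" (e.g. StableDiffusionXLPipeline), and the old equality test matched only the base one. A minimal sketch of the difference, using stub classes so it runs without diffusers installed:

```python
# Stub classes standing in for diffusers pipelines (names only).
class StableDiffusionPipeline: pass
class StableDiffusionXLPipeline: pass

for cls in (StableDiffusionPipeline, StableDiffusionXLPipeline):
    old_check = cls.__name__ == "StableDiffusionPipeline"  # before this patch
    new_check = "StableDiffusion" in cls.__name__          # after this patch
    print(f"{cls.__name__}: old={old_check}, new={new_check}")
# StableDiffusionPipeline: old=True, new=True
# StableDiffusionXLPipeline: old=False, new=True
```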
@@ -1246,8 +1246,8 @@ def _optimize_ipex(model, qtype=ggml_tensor_qtype["bf16"]):

 def _optimize_post(model, lightweight_bmm=False):
     try:
-        from diffusers import StableDiffusionPipeline
-        if isinstance(model, StableDiffusionPipeline):
+        from diffusers import DiffusionPipeline
+        if isinstance(model, DiffusionPipeline):
             from ipex_llm.transformers.models.sd15 import AttnProcessor2_0
             model.unet.set_attn_processor(AttnProcessor2_0())
             return model
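Note: switching the isinstance check to DiffusionPipeline (the common base class of all diffusers pipelines) keeps _optimize_post in step with the looser check above. A hypothetical end-to-end usage sketch; the checkpoint id is illustrative, and calling optimize_model with default arguments is an assumption (whether extra arguments such as low_bit matter for diffusion pipelines isn't shown in this diff):

```python
import torch
from diffusers import StableDiffusionPipeline
from ipex_llm import optimize_model

# Illustrative checkpoint; any DiffusionPipeline subclass takes this path now.
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
)
pipe = optimize_model(pipe)  # assumed default-arg call; installs AttnProcessor2_0 on pipe.unet
pipe = pipe.to("xpu")        # Intel GPU device exposed by IPEX
image = pipe("an astronaut riding a horse").images[0]
```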
@@ -36,7 +36,8 @@ import math
 import torch
 from typing import Optional

-from ipex_llm.transformers.models.common import attention_softmax
+from ipex_llm.transformers.models.common import padding_qkv_hd, attention_softmax
+from ipex_llm.transformers.models.utils import use_sdp_non_causal
 from diffusers.models.attention_processor import Attention

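Note: padding_qkv_hd and use_sdp_non_causal are ipex_llm's own helpers; their implementations are not part of this diff. As a rough functional stand-in (an assumption, not the library code), padding_qkv_hd can be read as zero-padding the last (head_dim) dimension of q/k/v from 40 up to 64 and passing other sizes through:

```python
import torch
import torch.nn.functional as F

def padding_qkv_hd_sketch(q, k, v, from_hd=40, to_hd=64):
    # Rough stand-in for ipex_llm.transformers.models.common.padding_qkv_hd:
    # zero-pad head_dim so the fused SDP kernel sees a supported head size.
    if q.size(-1) == from_hd:
        pad = (0, to_hd - from_hd)  # pad only the last dimension
        q, k, v = F.pad(q, pad), F.pad(k, pad), F.pad(v, pad)
    return q, k, v

q = k = v = torch.randn(1, 8, 77, 40)
q2, k2, v2 = padding_qkv_hd_sketch(q, k, v)
print(q2.shape)  # torch.Size([1, 8, 77, 64])
```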
@@ -105,7 +106,11 @@ class AttnProcessor2_0:

         # the output of sdp = (batch, num_heads, seq_len, head_dim)
         # IPEX-LLM changes start
-        if head_dim in [40, 80]:
+        # padding head_dim 40 to 64
+        if query.device.type == "xpu" and query.dtype in [torch.half, torch.float]:
+            query, key, value = padding_qkv_hd(query, key, value, 40, 64)
+
+            if use_sdp_non_causal(head_dim, query.device, query.dtype):
                 import xe_addons
                 hidden_states = xe_addons.sdp_non_causal(query, key.contiguous(),
                                                          value.contiguous(), attention_mask)
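Note: the old gate (head_dim in [40, 80]) sent unpadded 40-wide heads to the fused kernel; the new path first pads 40 up to 64 so the kernel only has to handle the padded size, then asks use_sdp_non_causal whether the fused path applies at all. Zero-padding is safe here: padded query/key columns contribute nothing to the QK^T scores, and padded value columns only produce zero output columns, which get sliced off below. A standalone sanity check in plain PyTorch (not ipex-llm code; the fused xe_addons kernel is assumed to keep the original softmax scale, which the explicit scale kwarg reproduces here):

```python
import math
import torch
import torch.nn.functional as F

q, k, v = (torch.randn(1, 8, 77, 40) for _ in range(3))

ref = F.scaled_dot_product_attention(q, k, v, is_causal=False)

pad = lambda t: F.pad(t, (0, 24))  # zero-pad head_dim 40 -> 64
out = F.scaled_dot_product_attention(
    pad(q), pad(k), pad(v), is_causal=False,
    scale=1 / math.sqrt(40),  # keep the original head_dim's scale (torch >= 2.1)
)
out = out[..., :40]  # slice the padding back off, as the patch does

print(torch.allclose(ref, out, atol=1e-5))  # True
```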
@@ -116,6 +121,12 @@ class AttnProcessor2_0:
                 attn_weights = attn_weights + attention_mask
             attn_weights = attention_softmax(attn_weights)
             hidden_states = torch.matmul(attn_weights, value)
+
+            hidden_states = hidden_states[:, :, :, :head_dim]
+        else:
+            hidden_states = torch.nn.functional.scaled_dot_product_attention(
+                query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+            )
         # IPEX-LLM changes end

         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1,
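Note: the trailing slice (hidden_states[:, :, :, :head_dim]) removes the padded columns before the reshape, and the added else branch means non-XPU devices (or unsupported dtypes on XPU) now fall straight through to PyTorch's fused SDPA instead of the manual matmul/softmax path. A minimal sketch of that fallback on CPU:

```python
import torch

# (batch, num_heads, seq_len, head_dim) -- matches the comment in the patch.
query, key, value = (torch.randn(2, 8, 64, 40) for _ in range(3))

hidden_states = torch.nn.functional.scaled_dot_product_attention(
    query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False
)
print(hidden_states.shape)  # torch.Size([2, 8, 64, 40])
```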