fix and optimize sd (#12436)

This commit is contained in:
Yishuo Wang 2024-11-25 14:09:48 +08:00 committed by GitHub
parent f41405368a
commit be132c4209
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 25 additions and 14 deletions

View file

@ -229,7 +229,7 @@ def optimize_model(model, low_bit='sym_int4', optimize_llm=True, modules_to_not_
f"Unknown load_in_low_bit value: {low_bit}, expected:"
f" sym_int4, asym_int4, sym_int5, asym_int5 or sym_int8.")
invalidInputError(isinstance(model, torch.nn.Module) or
model.__class__.__name__ == "StableDiffusionPipeline",
"StableDiffusion" in model.__class__.__name__,
"model should be an instance of "
f"`torch.nn.Module`, but got {type(model)} at last.")
# To adapt vLLM models

View file

@ -1246,8 +1246,8 @@ def _optimize_ipex(model, qtype=ggml_tensor_qtype["bf16"]):
def _optimize_post(model, lightweight_bmm=False):
try:
from diffusers import StableDiffusionPipeline
if isinstance(model, StableDiffusionPipeline):
from diffusers import DiffusionPipeline
if isinstance(model, DiffusionPipeline):
from ipex_llm.transformers.models.sd15 import AttnProcessor2_0
model.unet.set_attn_processor(AttnProcessor2_0())
return model

View file

@ -36,7 +36,8 @@ import math
import torch
from typing import Optional
from ipex_llm.transformers.models.common import attention_softmax
from ipex_llm.transformers.models.common import padding_qkv_hd, attention_softmax
from ipex_llm.transformers.models.utils import use_sdp_non_causal
from diffusers.models.attention_processor import Attention
@ -105,17 +106,27 @@ class AttnProcessor2_0:
# the output of sdp = (batch, num_heads, seq_len, head_dim)
# IPEX-LLM changes start
if head_dim in [40, 80]:
import xe_addons
hidden_states = xe_addons.sdp_non_causal(query, key.contiguous(),
value.contiguous(), attention_mask)
# padding head_dim 40 to 64
if query.device.type == "xpu" and query.dtype in [torch.half, torch.float]:
query, key, value = padding_qkv_hd(query, key, value, 40, 64)
if use_sdp_non_causal(head_dim, query.device, query.dtype):
import xe_addons
hidden_states = xe_addons.sdp_non_causal(query, key.contiguous(),
value.contiguous(), attention_mask)
else:
scale = 1 / math.sqrt(head_dim)
attn_weights = torch.matmul(query * scale, key.transpose(-1, -2))
if attention_mask is not None:
attn_weights = attn_weights + attention_mask
attn_weights = attention_softmax(attn_weights)
hidden_states = torch.matmul(attn_weights, value)
hidden_states = hidden_states[:, :, :, :head_dim]
else:
scale = 1 / math.sqrt(head_dim)
attn_weights = torch.matmul(query * scale, key.transpose(-1, -2))
if attention_mask is not None:
attn_weights = attn_weights + attention_mask
attn_weights = attention_softmax(attn_weights)
hidden_states = torch.matmul(attn_weights, value)
hidden_states = torch.nn.functional.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
# IPEX-LLM changes end
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1,