diff --git a/python/llm/src/ipex_llm/optimize.py b/python/llm/src/ipex_llm/optimize.py
index e010646c..d8aa95f6 100644
--- a/python/llm/src/ipex_llm/optimize.py
+++ b/python/llm/src/ipex_llm/optimize.py
@@ -229,7 +229,7 @@ def optimize_model(model, low_bit='sym_int4', optimize_llm=True, modules_to_not_
                       f"Unknown load_in_low_bit value: {low_bit}, expected:"
                       f" sym_int4, asym_int4, sym_int5, asym_int5 or sym_int8.")
     invalidInputError(isinstance(model, torch.nn.Module) or
-                      model.__class__.__name__ == "StableDiffusionPipeline",
+                      "StableDiffusion" in model.__class__.__name__,
                       "model should be an instance of "
                       f"`torch.nn.Module`, but got {type(model)} at last.")
     # To adapt vLLM models
diff --git a/python/llm/src/ipex_llm/transformers/convert.py b/python/llm/src/ipex_llm/transformers/convert.py
index fa164d87..34815b40 100644
--- a/python/llm/src/ipex_llm/transformers/convert.py
+++ b/python/llm/src/ipex_llm/transformers/convert.py
@@ -1246,8 +1246,8 @@ def _optimize_ipex(model, qtype=ggml_tensor_qtype["bf16"]):
 
 def _optimize_post(model, lightweight_bmm=False):
     try:
-        from diffusers import StableDiffusionPipeline
-        if isinstance(model, StableDiffusionPipeline):
+        from diffusers import DiffusionPipeline
+        if isinstance(model, DiffusionPipeline):
             from ipex_llm.transformers.models.sd15 import AttnProcessor2_0
             model.unet.set_attn_processor(AttnProcessor2_0())
             return model
diff --git a/python/llm/src/ipex_llm/transformers/models/sd15.py b/python/llm/src/ipex_llm/transformers/models/sd15.py
index 60d657ee..769b2dd2 100644
--- a/python/llm/src/ipex_llm/transformers/models/sd15.py
+++ b/python/llm/src/ipex_llm/transformers/models/sd15.py
@@ -36,7 +36,8 @@
 import math
 import torch
 from typing import Optional
-from ipex_llm.transformers.models.common import attention_softmax
+from ipex_llm.transformers.models.common import padding_qkv_hd, attention_softmax
+from ipex_llm.transformers.models.utils import use_sdp_non_causal
 from diffusers.models.attention_processor import Attention
 
 
@@ -105,17 +106,27 @@ class AttnProcessor2_0:
         # the output of sdp = (batch, num_heads, seq_len, head_dim)
 
         # IPEX-LLM changes start
-        if head_dim in [40, 80]:
-            import xe_addons
-            hidden_states = xe_addons.sdp_non_causal(query, key.contiguous(),
-                                                     value.contiguous(), attention_mask)
+        # padding head_dim 40 to 64
+        if query.device.type == "xpu" and query.dtype in [torch.half, torch.float]:
+            query, key, value = padding_qkv_hd(query, key, value, 40, 64)
+
+            if use_sdp_non_causal(head_dim, query.device, query.dtype):
+                import xe_addons
+                hidden_states = xe_addons.sdp_non_causal(query, key.contiguous(),
+                                                         value.contiguous(), attention_mask)
+            else:
+                scale = 1 / math.sqrt(head_dim)
+                attn_weights = torch.matmul(query * scale, key.transpose(-1, -2))
+                if attention_mask is not None:
+                    attn_weights = attn_weights + attention_mask
+                attn_weights = attention_softmax(attn_weights)
+                hidden_states = torch.matmul(attn_weights, value)
+
+            hidden_states = hidden_states[:, :, :, :head_dim]
         else:
-            scale = 1 / math.sqrt(head_dim)
-            attn_weights = torch.matmul(query * scale, key.transpose(-1, -2))
-            if attention_mask is not None:
-                attn_weights = attn_weights + attention_mask
-            attn_weights = attention_softmax(attn_weights)
-            hidden_states = torch.matmul(attn_weights, value)
+            hidden_states = torch.nn.functional.scaled_dot_product_attention(
+                query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+            )
         # IPEX-LLM changes end
 
         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1,
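
Note on the sd15.py hunk: the new XPU path zero-pads head_dim 40 up to 64 (via padding_qkv_hd) and slices the result back with hidden_states[:, :, :, :head_dim]. This is sound because zero channels contribute nothing to the q @ k^T scores and the padded value channels only produce zero outputs that the slice discards. A minimal standalone sketch of that argument, using a hypothetical pad_head_dim helper rather than the library's padding_qkv_hd (whose zero-padding behavior is assumed here):

import math
import torch

def pad_head_dim(t, target):
    # Hypothetical helper (not the library's padding_qkv_hd): zero-pad the
    # last (head_dim) dimension of a (batch, heads, seq, head_dim) tensor.
    return torch.nn.functional.pad(t, (0, target - t.size(-1)))

batch, heads, seq, head_dim = 1, 8, 77, 40
q, k, v = (torch.randn(batch, heads, seq, head_dim) for _ in range(3))
scale = 1 / math.sqrt(head_dim)

# Reference: plain scaled-dot-product attention at the original head_dim.
ref = torch.softmax(torch.matmul(q * scale, k.transpose(-1, -2)), dim=-1) @ v

# Padded: extra zero channels add nothing to q @ k^T, and the padded value
# channels only yield zero output columns, which the final slice removes.
qp, kp, vp = (pad_head_dim(t, 64) for t in (q, k, v))
out = torch.softmax(torch.matmul(qp * scale, kp.transpose(-1, -2)), dim=-1) @ vp
out = out[..., :head_dim]

assert torch.allclose(ref, out, atol=1e-5)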
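End-to-end, the relaxed check in optimize.py ("StableDiffusion" in the class name) plus the switch to DiffusionPipeline in convert.py means optimize_model can accept Stable Diffusion pipeline variants and attach the patched AttnProcessor2_0 to the UNet. A rough usage sketch, assuming an XPU-enabled IPEX-LLM install; the model id, dtype, and low_bit=None (skip quantization, apply only the attention optimization) are illustrative rather than taken from this patch:

import torch
from diffusers import StableDiffusionPipeline
from ipex_llm import optimize_model

# Any pipeline whose class name contains "StableDiffusion" passes the new check.
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5",
                                               torch_dtype=torch.half)
pipe = optimize_model(pipe, low_bit=None)  # assumed: sets the sd15 AttnProcessor2_0 on pipe.unet
pipe = pipe.to("xpu")

image = pipe("an astronaut riding a horse on the moon").images[0]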