diff --git a/python/llm/src/ipex_llm/transformers/convert.py b/python/llm/src/ipex_llm/transformers/convert.py
index 36aa98c1..7a1274d5 100644
--- a/python/llm/src/ipex_llm/transformers/convert.py
+++ b/python/llm/src/ipex_llm/transformers/convert.py
@@ -1246,10 +1246,14 @@ def _optimize_ipex(model, qtype=ggml_tensor_qtype["bf16"]):
 
 def _optimize_post(model, lightweight_bmm=False):
     try:
-        from diffusers import DiffusionPipeline
+        from diffusers import DiffusionPipeline, StableDiffusionXLPipeline
         if isinstance(model, DiffusionPipeline):
             from ipex_llm.transformers.models.sd import AttnProcessor2_0
             model.unet.set_attn_processor(AttnProcessor2_0())
+
+        if isinstance(model, StableDiffusionXLPipeline):
+            from ipex_llm.transformers.models.sd import upcast_vae
+            model.upcast_vae = MethodType(upcast_vae, model)
         return model
     except ModuleNotFoundError:
         pass
diff --git a/python/llm/src/ipex_llm/transformers/models/sd.py b/python/llm/src/ipex_llm/transformers/models/sd.py
index 7bd4cf82..50003903 100644
--- a/python/llm/src/ipex_llm/transformers/models/sd.py
+++ b/python/llm/src/ipex_llm/transformers/models/sd.py
@@ -36,6 +36,7 @@
 import math
 import torch
 from typing import Optional
+from ipex_llm.transformers.utils import get_xpu_device_type
 from ipex_llm.transformers.models.common import padding_qkv_hd, attention_softmax
 from ipex_llm.transformers.models.utils import use_sdp_non_causal
 from diffusers.models.attention_processor import Attention
@@ -148,3 +149,13 @@ class AttnProcessor2_0:
         hidden_states = hidden_states / attn.rescale_output_factor
 
         return hidden_states
+
+
+def upcast_vae(self):
+    # workaround overflow and ipex's bugs
+    if get_xpu_device_type(self.vae.post_quant_conv.weight) in ["arc", "flex", "pvc"]:
+        self.vae.to(torch.bfloat16)
+    else:
+        self.vae.decoder.up_blocks.to(torch.bfloat16)
+        self.vae.decoder.conv_norm_out.to(torch.bfloat16)
+        self.vae.decoder.conv_out.to(torch.bfloat16)
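
A minimal usage sketch (not part of the patch): it assumes the SDXL checkpoint id and the direct call to _optimize_post are illustrative; in practice the public ipex_llm optimization entry point is expected to route through _optimize_post shown above. With this patch, the pipeline's built-in upcast_vae (which diffusers calls internally when it decides to upcast an fp16 VAE) is rebound so the VAE moves to bfloat16 instead of float32, working around fp16 overflow and the IPEX issues noted in the code comment.

import torch
from diffusers import StableDiffusionXLPipeline
from ipex_llm.transformers.convert import _optimize_post  # function patched above

# Illustrative SDXL checkpoint; any StableDiffusionXLPipeline exercises the new branch.
pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
)
pipe = _optimize_post(pipe)  # binds the bf16 upcast_vae via MethodType (this patch)
pipe = pipe.to("xpu")        # assumes an Intel GPU with intel_extension_for_pytorch

# If the pipeline upcasts the fp16 VAE during decoding, the patched upcast_vae runs:
# on arc/flex/pvc the whole VAE goes to bfloat16, elsewhere only the decoder tail.
image = pipe("an astronaut riding a horse").images[0]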