optimize sdxl again (#12441)

Yishuo Wang 2024-11-25 17:46:46 +08:00 committed by GitHub
parent b9abb8a285
commit cdd41f5e4c
2 changed files with 16 additions and 1 deletion


@@ -1246,10 +1246,14 @@ def _optimize_ipex(model, qtype=ggml_tensor_qtype["bf16"]):
 def _optimize_post(model, lightweight_bmm=False):
     try:
-        from diffusers import DiffusionPipeline
+        from diffusers import DiffusionPipeline, StableDiffusionXLPipeline
         if isinstance(model, DiffusionPipeline):
             from ipex_llm.transformers.models.sd import AttnProcessor2_0
             model.unet.set_attn_processor(AttnProcessor2_0())
+            if isinstance(model, StableDiffusionXLPipeline):
+                from ipex_llm.transformers.models.sd import upcast_vae
+                model.upcast_vae = MethodType(upcast_vae, model)
             return model
     except ModuleNotFoundError:
         pass
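The MethodType call above binds the module-level upcast_vae function to the pipeline instance, so later pipe.upcast_vae() calls dispatch to the patched version instead of diffusers' default. A minimal, self-contained sketch of that binding pattern (the Pipeline class and attribute names below are illustrative placeholders, not code from this repository):

from types import MethodType

class Pipeline:
    # stand-in for StableDiffusionXLPipeline
    vae_dtype = "float16"

def upcast_vae(self):
    # plain function: the first argument plays the role of `self`
    self.vae_dtype = "bfloat16"

pipe = Pipeline()
pipe.upcast_vae = MethodType(upcast_vae, pipe)  # bind to this single instance
pipe.upcast_vae()                               # now behaves like a normal method
assert pipe.vae_dtype == "bfloat16"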


@@ -36,6 +36,7 @@ import math
 import torch
 from typing import Optional
+from ipex_llm.transformers.utils import get_xpu_device_type
 from ipex_llm.transformers.models.common import padding_qkv_hd, attention_softmax
 from ipex_llm.transformers.models.utils import use_sdp_non_causal
 from diffusers.models.attention_processor import Attention
@@ -148,3 +149,13 @@ class AttnProcessor2_0:
         hidden_states = hidden_states / attn.rescale_output_factor

         return hidden_states
+
+
+def upcast_vae(self):
+    # workaround overflow and ipex's bugs
+    if get_xpu_device_type(self.vae.post_quant_conv.weight) in ["arc", "flex", "pvc"]:
+        self.vae.to(torch.bfloat16)
+    else:
+        self.vae.decoder.up_blocks.to(torch.bfloat16)
+        self.vae.decoder.conv_norm_out.to(torch.bfloat16)
+        self.vae.decoder.conv_out.to(torch.bfloat16)
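For context on when this function runs: diffusers' SDXL pipeline decides at decode time whether the fp16 VAE needs upcasting and, if so, calls self.upcast_vae() before decoding the latents; with the MethodType binding above, that call now lands in this bfloat16 implementation on XPU. The sketch below paraphrases that decode-time flow and may differ across diffusers versions; `pipe` and `latents` are assumed to come from an already-optimized StableDiffusionXLPipeline run:

import torch

needs_upcasting = pipe.vae.dtype == torch.float16 and pipe.vae.config.force_upcast
if needs_upcasting:
    pipe.upcast_vae()  # with this patch: bf16 on arc/flex/pvc instead of fp32
    # match the latents' dtype to whatever the VAE was cast to
    latents = latents.to(next(iter(pipe.vae.post_quant_conv.parameters())).dtype)
image = pipe.vae.decode(latents / pipe.vae.config.scaling_factor, return_dict=False)[0]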