optimize sdxl again (#12441)
This commit is contained in:
parent
b9abb8a285
commit
cdd41f5e4c
2 changed files with 16 additions and 1 deletions
|
|
@ -1246,10 +1246,14 @@ def _optimize_ipex(model, qtype=ggml_tensor_qtype["bf16"]):
|
|||
|
||||
def _optimize_post(model, lightweight_bmm=False):
|
||||
try:
|
||||
from diffusers import DiffusionPipeline
|
||||
from diffusers import DiffusionPipeline, StableDiffusionXLPipeline
|
||||
if isinstance(model, DiffusionPipeline):
|
||||
from ipex_llm.transformers.models.sd import AttnProcessor2_0
|
||||
model.unet.set_attn_processor(AttnProcessor2_0())
|
||||
|
||||
if isinstance(model, StableDiffusionXLPipeline):
|
||||
from ipex_llm.transformers.models.sd import upcast_vae
|
||||
model.upcast_vae = MethodType(upcast_vae, model)
|
||||
return model
|
||||
except ModuleNotFoundError:
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -36,6 +36,7 @@ import math
|
|||
import torch
|
||||
from typing import Optional
|
||||
|
||||
from ipex_llm.transformers.utils import get_xpu_device_type
|
||||
from ipex_llm.transformers.models.common import padding_qkv_hd, attention_softmax
|
||||
from ipex_llm.transformers.models.utils import use_sdp_non_causal
|
||||
from diffusers.models.attention_processor import Attention
|
||||
|
|
@ -148,3 +149,13 @@ class AttnProcessor2_0:
|
|||
hidden_states = hidden_states / attn.rescale_output_factor
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
def upcast_vae(self):
|
||||
# workaround overflow and ipex's bugs
|
||||
if get_xpu_device_type(self.vae.post_quant_conv.weight) in ["arc", "flex", "pvc"]:
|
||||
self.vae.to(torch.bfloat16)
|
||||
else:
|
||||
self.vae.decoder.up_blocks.to(torch.bfloat16)
|
||||
self.vae.decoder.conv_norm_out.to(torch.bfloat16)
|
||||
self.vae.decoder.conv_out.to(torch.bfloat16)
|
||||
|
|
|
|||
Loading…
Reference in a new issue