optimize sdxl again (#12441)
This commit is contained in:
parent
b9abb8a285
commit
cdd41f5e4c
2 changed files with 16 additions and 1 deletions
|
|
@ -1246,10 +1246,14 @@ def _optimize_ipex(model, qtype=ggml_tensor_qtype["bf16"]):
|
||||||
|
|
||||||
def _optimize_post(model, lightweight_bmm=False):
|
def _optimize_post(model, lightweight_bmm=False):
|
||||||
try:
|
try:
|
||||||
from diffusers import DiffusionPipeline
|
from diffusers import DiffusionPipeline, StableDiffusionXLPipeline
|
||||||
if isinstance(model, DiffusionPipeline):
|
if isinstance(model, DiffusionPipeline):
|
||||||
from ipex_llm.transformers.models.sd import AttnProcessor2_0
|
from ipex_llm.transformers.models.sd import AttnProcessor2_0
|
||||||
model.unet.set_attn_processor(AttnProcessor2_0())
|
model.unet.set_attn_processor(AttnProcessor2_0())
|
||||||
|
|
||||||
|
if isinstance(model, StableDiffusionXLPipeline):
|
||||||
|
from ipex_llm.transformers.models.sd import upcast_vae
|
||||||
|
model.upcast_vae = MethodType(upcast_vae, model)
|
||||||
return model
|
return model
|
||||||
except ModuleNotFoundError:
|
except ModuleNotFoundError:
|
||||||
pass
|
pass
|
||||||
|
|
|
||||||
|
|
@ -36,6 +36,7 @@ import math
|
||||||
import torch
|
import torch
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
from ipex_llm.transformers.utils import get_xpu_device_type
|
||||||
from ipex_llm.transformers.models.common import padding_qkv_hd, attention_softmax
|
from ipex_llm.transformers.models.common import padding_qkv_hd, attention_softmax
|
||||||
from ipex_llm.transformers.models.utils import use_sdp_non_causal
|
from ipex_llm.transformers.models.utils import use_sdp_non_causal
|
||||||
from diffusers.models.attention_processor import Attention
|
from diffusers.models.attention_processor import Attention
|
||||||
|
|
@ -148,3 +149,13 @@ class AttnProcessor2_0:
|
||||||
hidden_states = hidden_states / attn.rescale_output_factor
|
hidden_states = hidden_states / attn.rescale_output_factor
|
||||||
|
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
def upcast_vae(self):
    """Cast the pipeline's VAE weights to bfloat16.

    Bound onto a ``StableDiffusionXLPipeline`` as a replacement for its
    stock ``upcast_vae``; works around fp16 overflow and IPEX bugs (per
    the original author's note).
    """
    device_type = get_xpu_device_type(self.vae.post_quant_conv.weight)
    if device_type in ["arc", "flex", "pvc"]:
        # On these XPU device families the whole VAE is upcast.
        self.vae.to(torch.bfloat16)
    else:
        # Elsewhere, only the decoder tail is upcast.
        for module in (
            self.vae.decoder.up_blocks,
            self.vae.decoder.conv_norm_out,
            self.vae.decoder.conv_out,
        ):
            module.to(torch.bfloat16)
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue