optimize sdxl again (#12441)
parent b9abb8a285
commit cdd41f5e4c
2 changed files with 16 additions and 1 deletion
@@ -1246,10 +1246,14 @@ def _optimize_ipex(model, qtype=ggml_tensor_qtype["bf16"]):
 def _optimize_post(model, lightweight_bmm=False):
     try:
-        from diffusers import DiffusionPipeline
+        from diffusers import DiffusionPipeline, StableDiffusionXLPipeline
         if isinstance(model, DiffusionPipeline):
             from ipex_llm.transformers.models.sd import AttnProcessor2_0
             model.unet.set_attn_processor(AttnProcessor2_0())
+
+            if isinstance(model, StableDiffusionXLPipeline):
+                from ipex_llm.transformers.models.sd import upcast_vae
+                model.upcast_vae = MethodType(upcast_vae, model)
             return model
     except ModuleNotFoundError:
         pass
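For context, `MethodType` (from Python's `types` module, already imported elsewhere in this file) binds the standalone `upcast_vae` function to the pipeline instance, so later calls to `model.upcast_vae()` dispatch to the ipex-llm version instead of diffusers' default float32 upcast. A minimal sketch of the same binding pattern, using a hypothetical class as a stand-in for the pipeline:

from types import MethodType

class FakePipeline:
    """Hypothetical stand-in for StableDiffusionXLPipeline."""
    def upcast_vae(self):
        print("original: upcast VAE to float32")

def patched_upcast_vae(self):
    # MethodType makes `self` the pipeline instance, mirroring
    # how the commit rebinds `upcast_vae` on the model object.
    print("patched: upcast VAE to bfloat16 on XPU")

pipe = FakePipeline()
pipe.upcast_vae = MethodType(patched_upcast_vae, pipe)
pipe.upcast_vae()  # -> "patched: upcast VAE to bfloat16 on XPU"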
@@ -36,6 +36,7 @@ import math
 import torch
 from typing import Optional
 
+from ipex_llm.transformers.utils import get_xpu_device_type
 from ipex_llm.transformers.models.common import padding_qkv_hd, attention_softmax
 from ipex_llm.transformers.models.utils import use_sdp_non_causal
 from diffusers.models.attention_processor import Attention
@@ -148,3 +149,13 @@ class AttnProcessor2_0:
         hidden_states = hidden_states / attn.rescale_output_factor
 
         return hidden_states
+
+
+def upcast_vae(self):
+    # workaround overflow and ipex's bugs
+    if get_xpu_device_type(self.vae.post_quant_conv.weight) in ["arc", "flex", "pvc"]:
+        self.vae.to(torch.bfloat16)
+    else:
+        self.vae.decoder.up_blocks.to(torch.bfloat16)
+        self.vae.decoder.conv_norm_out.to(torch.bfloat16)
+        self.vae.decoder.conv_out.to(torch.bfloat16)
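The SDXL VAE is known to overflow in float16, which is why diffusers' StableDiffusionXLPipeline calls `upcast_vae()` (normally casting to float32) before decoding when `vae.config.force_upcast` is set. The override above keeps the upcast in bfloat16 instead: the whole VAE on Arc/Flex/PVC, and only the decoder tail on other XPU devices. A rough usage sketch, not part of the commit, assuming an Intel XPU build of PyTorch and ipex-llm's `optimize_model` entry point (which reaches `_optimize_post`):

import torch
from diffusers import StableDiffusionXLPipeline
from ipex_llm import optimize_model

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
)
# optimize_model swaps in ipex-llm's AttnProcessor2_0 and, for SDXL,
# rebinds upcast_vae as shown in the first hunk.
pipe = optimize_model(pipe)
pipe.to("xpu")

# Before VAE decode, the pipeline calls the patched upcast_vae(),
# casting to bfloat16 rather than float32.
image = pipe("a photo of an astronaut riding a horse").images[0]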