add sdxl and lora-lcm optimization (#12444)
* add sdxl and lora-lcm optimization * fix openjourney speed drop
This commit is contained in:
parent
0e23bd779f
commit
66bd7abae4
3 changed files with 11 additions and 7 deletions
|
|
@ -17,7 +17,7 @@
|
|||
|
||||
import torch
|
||||
from diffusers import DiffusionPipeline, LCMScheduler
|
||||
import ipex_llm
|
||||
from ipex_llm import optimize_model
|
||||
import argparse
|
||||
import time
|
||||
|
||||
|
|
@ -25,8 +25,10 @@ import time
|
|||
def main(args):
|
||||
pipe = DiffusionPipeline.from_pretrained(
|
||||
args.repo_id_or_model_path,
|
||||
torch_dtype=torch.bfloat16,
|
||||
).to("xpu")
|
||||
torch_dtype=torch.float16,
|
||||
)
|
||||
pipe = optimize_model(pipe, low_bit=None)
|
||||
pipe.to("xpu")
|
||||
|
||||
# set scheduler
|
||||
pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@
|
|||
|
||||
from diffusers import AutoPipelineForText2Image
|
||||
import torch
|
||||
import ipex_llm
|
||||
from ipex_llm import optimize_model
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
import argparse
|
||||
|
|
@ -27,9 +27,11 @@ import time
|
|||
def main(args):
|
||||
pipeline_text2image = AutoPipelineForText2Image.from_pretrained(
|
||||
args.repo_id_or_model_path,
|
||||
torch_dtype=torch.bfloat16,
|
||||
torch_dtype=torch.float16,
|
||||
use_safetensors=True
|
||||
).to("xpu")
|
||||
)
|
||||
pipeline_text2image = optimize_model(pipeline_text2image, low_bit=None)
|
||||
pipeline_text2image.to("xpu")
|
||||
|
||||
with torch.inference_mode():
|
||||
# warmup
|
||||
|
|
|
|||
|
|
@ -111,7 +111,7 @@ class AttnProcessor2_0:
|
|||
# padding head_dim 40 to 64
|
||||
query, key, value = padding_qkv_hd(query, key, value, 40, 64)
|
||||
|
||||
if use_sdp_non_causal(head_dim, query.device, query.dtype):
|
||||
if use_sdp_non_causal(query.size(-1), query.device, query.dtype):
|
||||
import xe_addons
|
||||
hidden_states = xe_addons.sdp_non_causal(query, key.contiguous(),
|
||||
value.contiguous(), attention_mask)
|
||||
|
|
|
|||
Loading…
Reference in a new issue