add sdxl and lora-lcm optimization (#12444)

* add sdxl and lora-lcm optimization

* fix openjourney speed drop
This commit is contained in:
Jinhe 2024-11-26 11:38:09 +08:00 committed by GitHub
parent 0e23bd779f
commit 66bd7abae4
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 11 additions and 7 deletions

View file

@@ -17,7 +17,7 @@
 import torch
 from diffusers import DiffusionPipeline, LCMScheduler
-import ipex_llm
+from ipex_llm import optimize_model
 import argparse
 import time
@@ -25,8 +25,10 @@ import time
 def main(args):
 pipe = DiffusionPipeline.from_pretrained(
 args.repo_id_or_model_path,
-torch_dtype=torch.bfloat16,
-).to("xpu")
+torch_dtype=torch.float16,
+)
+pipe = optimize_model(pipe, low_bit=None)
+pipe.to("xpu")
 # set scheduler
 pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)

View file

@@ -17,7 +17,7 @@
 from diffusers import AutoPipelineForText2Image
 import torch
-import ipex_llm
+from ipex_llm import optimize_model
 import numpy as np
 from PIL import Image
 import argparse
@@ -27,9 +27,11 @@ import time
 def main(args):
 pipeline_text2image = AutoPipelineForText2Image.from_pretrained(
 args.repo_id_or_model_path,
-torch_dtype=torch.bfloat16,
+torch_dtype=torch.float16,
 use_safetensors=True
-).to("xpu")
+)
+pipeline_text2image = optimize_model(pipeline_text2image, low_bit=None)
+pipeline_text2image.to("xpu")
 with torch.inference_mode():
 # warmup

View file

@@ -111,7 +111,7 @@ class AttnProcessor2_0:
 # padding head_dim 40 to 64
 query, key, value = padding_qkv_hd(query, key, value, 40, 64)
-if use_sdp_non_causal(head_dim, query.device, query.dtype):
+if use_sdp_non_causal(query.size(-1), query.device, query.dtype):
 import xe_addons
 hidden_states = xe_addons.sdp_non_causal(query, key.contiguous(),
 value.contiguous(), attention_mask)