add sdxl and lora-lcm optimization (#12444)
* add sdxl and lora-lcm optimization * fix openjourney speed drop
This commit is contained in:
parent
0e23bd779f
commit
66bd7abae4
3 changed files with 11 additions and 7 deletions
|
|
@ -17,7 +17,7 @@
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from diffusers import DiffusionPipeline, LCMScheduler
|
from diffusers import DiffusionPipeline, LCMScheduler
|
||||||
import ipex_llm
|
from ipex_llm import optimize_model
|
||||||
import argparse
|
import argparse
|
||||||
import time
|
import time
|
||||||
|
|
||||||
|
|
@ -25,8 +25,10 @@ import time
|
||||||
def main(args):
|
def main(args):
|
||||||
pipe = DiffusionPipeline.from_pretrained(
|
pipe = DiffusionPipeline.from_pretrained(
|
||||||
args.repo_id_or_model_path,
|
args.repo_id_or_model_path,
|
||||||
torch_dtype=torch.bfloat16,
|
torch_dtype=torch.float16,
|
||||||
).to("xpu")
|
)
|
||||||
|
pipe = optimize_model(pipe, low_bit=None)
|
||||||
|
pipe.to("xpu")
|
||||||
|
|
||||||
# set scheduler
|
# set scheduler
|
||||||
pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
|
pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
|
||||||
|
|
|
||||||
|
|
@ -17,7 +17,7 @@
|
||||||
|
|
||||||
from diffusers import AutoPipelineForText2Image
|
from diffusers import AutoPipelineForText2Image
|
||||||
import torch
|
import torch
|
||||||
import ipex_llm
|
from ipex_llm import optimize_model
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
import argparse
|
import argparse
|
||||||
|
|
@ -27,9 +27,11 @@ import time
|
||||||
def main(args):
|
def main(args):
|
||||||
pipeline_text2image = AutoPipelineForText2Image.from_pretrained(
|
pipeline_text2image = AutoPipelineForText2Image.from_pretrained(
|
||||||
args.repo_id_or_model_path,
|
args.repo_id_or_model_path,
|
||||||
torch_dtype=torch.bfloat16,
|
torch_dtype=torch.float16,
|
||||||
use_safetensors=True
|
use_safetensors=True
|
||||||
).to("xpu")
|
)
|
||||||
|
pipeline_text2image = optimize_model(pipeline_text2image, low_bit=None)
|
||||||
|
pipeline_text2image.to("xpu")
|
||||||
|
|
||||||
with torch.inference_mode():
|
with torch.inference_mode():
|
||||||
# warmup
|
# warmup
|
||||||
|
|
|
||||||
|
|
@ -111,7 +111,7 @@ class AttnProcessor2_0:
|
||||||
# padding head_dim 40 to 64
|
# padding head_dim 40 to 64
|
||||||
query, key, value = padding_qkv_hd(query, key, value, 40, 64)
|
query, key, value = padding_qkv_hd(query, key, value, 40, 64)
|
||||||
|
|
||||||
if use_sdp_non_causal(head_dim, query.device, query.dtype):
|
if use_sdp_non_causal(query.size(-1), query.device, query.dtype):
|
||||||
import xe_addons
|
import xe_addons
|
||||||
hidden_states = xe_addons.sdp_non_causal(query, key.contiguous(),
|
hidden_states = xe_addons.sdp_non_causal(query, key.contiguous(),
|
||||||
value.contiguous(), attention_mask)
|
value.contiguous(), attention_mask)
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue