add sdxl and lora-lcm optimization (#12444)
* add sdxl and lora-lcm optimization
* fix openjourney speed drop
parent 0e23bd779f
commit 66bd7abae4

3 changed files with 11 additions and 7 deletions
@@ -17,7 +17,7 @@
 
 import torch
 from diffusers import DiffusionPipeline, LCMScheduler
-import ipex_llm
+from ipex_llm import optimize_model
 import argparse
 import time
 
@@ -25,8 +25,10 @@ import time
 def main(args):
     pipe = DiffusionPipeline.from_pretrained(
         args.repo_id_or_model_path,
-        torch_dtype=torch.bfloat16,
-    ).to("xpu")
+        torch_dtype=torch.float16,
+    )
+    pipe = optimize_model(pipe, low_bit=None)
+    pipe.to("xpu")
 
     # set scheduler
     pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
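Taken together, the changes to this first file switch the LCM example from bfloat16 to float16, pass the pipeline through ipex_llm's optimize_model on the host, and only then move it to the XPU. A minimal sketch of the updated flow; the model id, prompt, and LCM-LoRA loading below are illustrative assumptions, not taken from this diff:

import torch
from diffusers import DiffusionPipeline, LCMScheduler
from ipex_llm import optimize_model

# Load in float16 first; optimize_model rewrites the pipeline's modules,
# and only afterwards is everything moved to the Intel XPU device.
pipe = DiffusionPipeline.from_pretrained(
    "Lykon/dreamshaper-7",                    # assumption: any LCM-LoRA-compatible base model
    torch_dtype=torch.float16,
)
pipe = optimize_model(pipe, low_bit=None)     # low_bit=None: optimize without quantizing weights
pipe.to("xpu")

# set scheduler (unchanged in the diff)
pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)

# assumption: an LCM-LoRA checkpoint is loaded, as the commit title suggests
pipe.load_lora_weights("latent-consistency/lcm-lora-sdv1-5")

with torch.inference_mode():
    # LCM typically runs with very few steps and low guidance
    image = pipe("a rendering of a lighthouse at dusk",
                 num_inference_steps=4, guidance_scale=1.0).images[0]
image.save("out.png")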
@@ -17,7 +17,7 @@
 
 from diffusers import AutoPipelineForText2Image
 import torch
-import ipex_llm
+from ipex_llm import optimize_model
 import numpy as np
 from PIL import Image
 import argparse
@@ -27,9 +27,11 @@ import time
 def main(args):
     pipeline_text2image = AutoPipelineForText2Image.from_pretrained(
         args.repo_id_or_model_path,
-        torch_dtype=torch.bfloat16,
+        torch_dtype=torch.float16,
         use_safetensors=True
-    ).to("xpu")
+    )
+    pipeline_text2image = optimize_model(pipeline_text2image, low_bit=None)
+    pipeline_text2image.to("xpu")
 
     with torch.inference_mode():
         # warmup
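The new SDXL example follows the same pattern: load in float16 with safetensors, optimize, then move to the XPU. A minimal sketch under the same caveats (the checkpoint id and prompt are assumptions; the script's timing and numpy/PIL post-processing are not shown in this view):

import torch
from diffusers import AutoPipelineForText2Image
from ipex_llm import optimize_model

pipeline_text2image = AutoPipelineForText2Image.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",   # assumption: the usual SDXL base checkpoint
    torch_dtype=torch.float16,
    use_safetensors=True,
)
# Optimize on the host, then move the optimized pipeline to the XPU.
pipeline_text2image = optimize_model(pipeline_text2image, low_bit=None)
pipeline_text2image.to("xpu")

with torch.inference_mode():
    # warmup pass first, as in the example, so later runs can be timed fairly
    _ = pipeline_text2image("An astronaut riding a green horse").images[0]
    image = pipeline_text2image("An astronaut riding a green horse").images[0]
image.save("sdxl-out.png")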
@@ -111,7 +111,7 @@ class AttnProcessor2_0:
             # padding head_dim 40 to 64
             query, key, value = padding_qkv_hd(query, key, value, 40, 64)
 
-            if use_sdp_non_causal(head_dim, query.device, query.dtype):
+            if use_sdp_non_causal(query.size(-1), query.device, query.dtype):
                 import xe_addons
                 hidden_states = xe_addons.sdp_non_causal(query, key.contiguous(),
                                                          value.contiguous(), attention_mask)
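This last hunk is the openjourney speed fix from the commit message. padding_qkv_hd pads the head dimension from 40 up to 64 so the fused SDP kernel can be used, but the old code then passed the stale head_dim variable (still 40) to use_sdp_non_causal, which rejected the fast path and fell back to slower attention. Checking query.size(-1) instead tests the actual, padded head size. A self-contained sketch of the pitfall, with F.pad standing in for padding_qkv_hd, whose internals are not shown in this diff:

import torch
import torch.nn.functional as F

query = torch.randn(1, 8, 77, 40)         # (batch, heads, tokens, head_dim); 40 as in openjourney
head_dim = query.size(-1)                  # captured before padding: 40

query = F.pad(query, (0, 64 - head_dim))   # pad the last dim from 40 to 64

# The local variable no longer describes the tensor:
assert head_dim == 40
assert query.size(-1) == 64
# A support check fed the stale 40 would reject the fused kernel even though
# the padded tensors qualify; query.size(-1) reflects what will actually run.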