fix and optimize sd (#12436)
parent f41405368a
commit be132c4209

3 changed files with 25 additions and 14 deletions
@@ -229,7 +229,7 @@ def optimize_model(model, low_bit='sym_int4', optimize_llm=True, modules_to_not_
                       f"Unknown load_in_low_bit value: {low_bit}, expected:"
                       f" sym_int4, asym_int4, sym_int5, asym_int5 or sym_int8.")
     invalidInputError(isinstance(model, torch.nn.Module) or
-                      model.__class__.__name__ == "StableDiffusionPipeline",
+                      "StableDiffusion" in model.__class__.__name__,
                       "model should be an instance of "
                       f"`torch.nn.Module`, but got {type(model)} at last.")
     # To adapt vLLM models
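Note on this hunk: the entry check in optimize_model is relaxed from an exact class-name match to a substring match, so any diffusers pipeline whose class name contains "StableDiffusion" is accepted alongside plain torch.nn.Module models. A minimal sketch of the accepted-input logic; the helper name _is_supported and the XL/Img2Img class names are illustrative only, not taken from this diff.

import torch

def _is_supported(model) -> bool:
    # Old check: model.__class__.__name__ == "StableDiffusionPipeline"
    # New check: any class name containing "StableDiffusion" passes,
    # e.g. StableDiffusionXLPipeline or StableDiffusionImg2ImgPipeline.
    return isinstance(model, torch.nn.Module) or \
        "StableDiffusion" in model.__class__.__name__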
@@ -1246,8 +1246,8 @@ def _optimize_ipex(model, qtype=ggml_tensor_qtype["bf16"]):
 
 def _optimize_post(model, lightweight_bmm=False):
     try:
-        from diffusers import StableDiffusionPipeline
-        if isinstance(model, StableDiffusionPipeline):
+        from diffusers import DiffusionPipeline
+        if isinstance(model, DiffusionPipeline):
             from ipex_llm.transformers.models.sd15 import AttnProcessor2_0
             model.unet.set_attn_processor(AttnProcessor2_0())
             return model
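Note on this hunk: _optimize_post now dispatches on diffusers' DiffusionPipeline base class rather than the concrete StableDiffusionPipeline, so Stable Diffusion variants beyond the 1.x text-to-image pipeline also reach the attention-processor swap. A small self-contained check; the pipeline classes below are standard diffusers classes, named only for illustration.

from diffusers import (
    DiffusionPipeline,
    StableDiffusionPipeline,
    StableDiffusionXLPipeline,
)

# Every diffusers pipeline derives from DiffusionPipeline, so both hold.
print(issubclass(StableDiffusionPipeline, DiffusionPipeline))    # True
print(issubclass(StableDiffusionXLPipeline, DiffusionPipeline))  # True

The swap itself still targets model.unet, so this branch presumes a UNet-based pipeline.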
@@ -36,7 +36,8 @@ import math
 import torch
 from typing import Optional
 
-from ipex_llm.transformers.models.common import attention_softmax
+from ipex_llm.transformers.models.common import padding_qkv_hd, attention_softmax
+from ipex_llm.transformers.models.utils import use_sdp_non_causal
 from diffusers.models.attention_processor import Attention
 
 
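Note on the new imports: padding_qkv_hd is used in the next hunk to pad the per-head dimension from 40 up to 64 before the fused SDP kernel, and use_sdp_non_causal gates whether that kernel is taken. The helper's exact behaviour is not shown in this diff; the sketch below assumes simple zero-padding of the last dimension, and the name padding_qkv_hd_sketch is hypothetical.

import torch
import torch.nn.functional as F

def padding_qkv_hd_sketch(q, k, v, hd, target):
    # q, k, v: (batch, num_heads, seq_len, head_dim)
    if q.size(-1) == hd and hd < target:
        pad = target - hd
        q = F.pad(q, (0, pad))  # zero columns leave q @ k^T unchanged
        k = F.pad(k, (0, pad))
        v = F.pad(v, (0, pad))  # extra output channels are sliced off afterwards
    return q, k, v

Because the padding is zeros, the attention scores are unchanged and the padded value channels only add trailing output channels, which the last hunk removes with hidden_states[:, :, :, :head_dim].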
@@ -105,17 +106,27 @@ class AttnProcessor2_0:
 
         # the output of sdp = (batch, num_heads, seq_len, head_dim)
         # IPEX-LLM changes start
-        if head_dim in [40, 80]:
-            import xe_addons
-            hidden_states = xe_addons.sdp_non_causal(query, key.contiguous(),
-                                                     value.contiguous(), attention_mask)
+        # padding head_dim 40 to 64
+        if query.device.type == "xpu" and query.dtype in [torch.half, torch.float]:
+            query, key, value = padding_qkv_hd(query, key, value, 40, 64)
+
+            if use_sdp_non_causal(head_dim, query.device, query.dtype):
+                import xe_addons
+                hidden_states = xe_addons.sdp_non_causal(query, key.contiguous(),
+                                                         value.contiguous(), attention_mask)
+            else:
+                scale = 1 / math.sqrt(head_dim)
+                attn_weights = torch.matmul(query * scale, key.transpose(-1, -2))
+                if attention_mask is not None:
+                    attn_weights = attn_weights + attention_mask
+                attn_weights = attention_softmax(attn_weights)
+                hidden_states = torch.matmul(attn_weights, value)
+
+            hidden_states = hidden_states[:, :, :, :head_dim]
         else:
-            scale = 1 / math.sqrt(head_dim)
-            attn_weights = torch.matmul(query * scale, key.transpose(-1, -2))
-            if attention_mask is not None:
-                attn_weights = attn_weights + attention_mask
-            attn_weights = attention_softmax(attn_weights)
-            hidden_states = torch.matmul(attn_weights, value)
+            hidden_states = torch.nn.functional.scaled_dot_product_attention(
+                query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+            )
         # IPEX-LLM changes end
 
         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1,
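Note on this hunk: on XPU the q/k/v heads are padded from 40 to 64, then either the fused xe_addons.sdp_non_causal kernel is used (when use_sdp_non_causal allows it) or an explicit scale/matmul/softmax/matmul fallback runs; off XPU the manual computation is replaced by PyTorch's scaled_dot_product_attention. The two formulations are numerically equivalent; a quick CPU self-check with attention_softmax replaced by torch.softmax as a stand-in and arbitrary shapes:

import math
import torch

q, k, v = (torch.randn(1, 8, 77, 40) for _ in range(3))

# Manual path, as in the else-branch that runs when the fused kernel is unavailable.
scale = 1 / math.sqrt(q.size(-1))
attn_weights = torch.softmax(torch.matmul(q * scale, k.transpose(-1, -2)), dim=-1)
manual = torch.matmul(attn_weights, v)

# PyTorch fused SDPA path used off XPU.
fused = torch.nn.functional.scaled_dot_product_attention(
    q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)

print(torch.allclose(manual, fused, atol=1e-5))  # expected: True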