small change (#12439)
This commit is contained in:
		
							parent
							
								
									be132c4209
								
							
						
					
					
						commit
						8164aed802
					
				
					 3 changed files with 3 additions and 3 deletions
				
			
		| 
						 | 
				
			
			@ -1248,7 +1248,7 @@ def _optimize_post(model, lightweight_bmm=False):
 | 
			
		|||
    try:
 | 
			
		||||
        from diffusers import DiffusionPipeline
 | 
			
		||||
        if isinstance(model, DiffusionPipeline):
 | 
			
		||||
            from ipex_llm.transformers.models.sd15 import AttnProcessor2_0
 | 
			
		||||
            from ipex_llm.transformers.models.sd import AttnProcessor2_0
 | 
			
		||||
            model.unet.set_attn_processor(AttnProcessor2_0())
 | 
			
		||||
            return model
 | 
			
		||||
    except ModuleNotFoundError:
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -106,8 +106,8 @@ class AttnProcessor2_0:
 | 
			
		|||
 | 
			
		||||
        # the output of sdp = (batch, num_heads, seq_len, head_dim)
 | 
			
		||||
        # IPEX-LLM changes start
 | 
			
		||||
        # padding head_dim 40 to 64
 | 
			
		||||
        if query.device.type == "xpu" and query.dtype in [torch.half, torch.float]:
 | 
			
		||||
            # padding head_dim 40 to 64
 | 
			
		||||
            query, key, value = padding_qkv_hd(query, key, value, 40, 64)
 | 
			
		||||
 | 
			
		||||
            if use_sdp_non_causal(head_dim, query.device, query.dtype):
 | 
			
		||||
| 
						 | 
				
			
			@ -329,7 +329,7 @@ def use_sdp_causal(q_len, kv_len, head_dim, query_states, training):
 | 
			
		|||
 | 
			
		||||
def use_sdp_non_causal(head_dim, device, dtype):
 | 
			
		||||
    return (
 | 
			
		||||
        head_dim in [40, 64, 80]
 | 
			
		||||
        head_dim in [64, 80, 128]
 | 
			
		||||
        and device.type == "xpu"                # GPU
 | 
			
		||||
        and dtype in [torch.float, torch.half]  # fp32/fp16
 | 
			
		||||
    )
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in a new issue