small change (#12439)

parent be132c4209
commit 8164aed802

3 changed files with 3 additions and 3 deletions
@@ -1248,7 +1248,7 @@ def _optimize_post(model, lightweight_bmm=False):
     try:
         from diffusers import DiffusionPipeline
         if isinstance(model, DiffusionPipeline):
-            from ipex_llm.transformers.models.sd15 import AttnProcessor2_0
+            from ipex_llm.transformers.models.sd import AttnProcessor2_0
             model.unet.set_attn_processor(AttnProcessor2_0())
             return model
     except ModuleNotFoundError:
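For context, set_attn_processor is the standard diffusers hook for swapping attention implementations; a minimal sketch of how the processor registered here ends up on a pipeline (the checkpoint name and the XPU device are illustrative assumptions, not part of this commit):

    import torch
    from diffusers import DiffusionPipeline
    from ipex_llm.transformers.models.sd import AttnProcessor2_0

    # Illustrative checkpoint; any UNet-based diffusers pipeline is handled
    # the same way by _optimize_post.
    pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5",
                                             torch_dtype=torch.half).to("xpu")
    # The same call _optimize_post makes: route every attention layer of the
    # UNet through the IPEX-LLM processor.
    pipe.unet.set_attn_processor(AttnProcessor2_0())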
@@ -106,8 +106,8 @@ class AttnProcessor2_0:
 
         # the output of sdp = (batch, num_heads, seq_len, head_dim)
         # IPEX-LLM changes start
-        # padding head_dim 40 to 64
         if query.device.type == "xpu" and query.dtype in [torch.half, torch.float]:
+            # padding head_dim 40 to 64
             query, key, value = padding_qkv_hd(query, key, value, 40, 64)
 
             if use_sdp_non_causal(head_dim, query.device, query.dtype):
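This hunk only moves the "padding head_dim 40 to 64" comment inside the XPU branch, where the padding actually happens. As a sketch of what a helper like padding_qkv_hd plausibly does (the real ipex-llm implementation may differ):

    import torch
    import torch.nn.functional as F

    def padding_qkv_hd_sketch(query, key, value, hd, padded_hd):
        # Zero-pad the trailing (head_dim) axis from hd up to padded_hd so a
        # fused SDP kernel that only supports fixed head sizes can be used.
        if query.size(-1) == hd:
            pad = (0, padded_hd - hd)   # (left, right) padding of the last dim
            query = F.pad(query, pad)
            key = F.pad(key, pad)
            value = F.pad(value, pad)
        return query, key, value

Zero columns in query and key leave every q·k dot product unchanged, and zero channels in value only produce zero output channels, so the attention output can simply be sliced back to the original 40 channels afterwards.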
@@ -329,7 +329,7 @@ def use_sdp_causal(q_len, kv_len, head_dim, query_states, training):
 
 def use_sdp_non_causal(head_dim, device, dtype):
     return (
-        head_dim in [40, 64, 80]
+        head_dim in [64, 80, 128]
         and device.type == "xpu"                # GPU
         and dtype in [torch.float, torch.half]  # fp32/fp16
     )
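Dropping 40 is consistent with the padding step above, which appears to lift a 40-wide head to 64 before the fused path is considered; 128 presumably extends coverage to models with wider heads. A minimal sketch of how such a gate is typically consumed (the gate body is copied from the new version of the hunk; the branch bodies are placeholders, not actual ipex-llm code):

    import torch

    def use_sdp_non_causal(head_dim, device, dtype):
        return (
            head_dim in [64, 80, 128]
            and device.type == "xpu"                # GPU
            and dtype in [torch.float, torch.half]  # fp32/fp16
        )

    q = torch.randn(1, 8, 77, 64, dtype=torch.half)  # (batch, heads, seq, head_dim)
    if use_sdp_non_causal(q.size(-1), q.device, q.dtype):
        ...  # fused non-causal SDP kernel (IPEX-LLM internal, XPU only)
    else:
        ...  # fall back to torch.nn.functional.scaled_dot_product_attention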