LLM: Disable esimd sdp for PVC GPU when batch size>1 (#10579)
* llm: disable esimd sdp for pvc when batch size > 1.
* fix logic.
* fix: avoid calling get_device_name twice.
parent e6c5a6a5e6
commit b44f7adbad

1 changed file with 9 additions and 5 deletions
@@ -317,11 +317,6 @@ def use_esimd_sdp(q_len, k_len, head_dim, query_states, attention_mask=None):
     elif query_states.dtype != torch.float16:
         # esimd_sdp only has optimization for FP16 now
         return False
-    elif query_states.shape[0] > 1 and attention_mask is not None:
-        # for batched input, can't accept attention_mask
-        # TODO: this check needs some time
-        if not torch.all(attention_mask.eq(0)):
-            return False

     device_name = torch.xpu.get_device_name(query_states.device.index)
     if device_name.startswith("Intel(R) Arc(TM) A") or \
@@ -333,6 +328,15 @@ def use_esimd_sdp(q_len, k_len, head_dim, query_states, attention_mask=None):
     else:
         return False

+    if query_states.shape[0] > 1 and device_name.startswith("Intel(R) Data Center GPU Max"):
+        # esimd_sdp not support PVC GPU when batch size > 1 for now
+        return False
+    if query_states.shape[0] > 1 and attention_mask is not None:
+        # for batched input, can't accept attention_mask
+        # TODO: this check needs some time
+        if not torch.all(attention_mask.eq(0)):
+            return False
+
     return True
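For readers skimming the change, here is a minimal sketch of the check order that results from the two hunks above. The helper name and the standalone-function framing are assumptions for illustration; in the patch the logic sits inline in use_esimd_sdp, after device_name has already been computed.

import torch

def _esimd_sdp_batch_checks(query_states, attention_mask, device_name):
    # Hypothetical helper mirroring the inline checks added by this commit.
    batch_size = query_states.shape[0]
    # esimd_sdp does not support PVC (Data Center GPU Max) when batch size > 1.
    if batch_size > 1 and device_name.startswith("Intel(R) Data Center GPU Max"):
        return False
    # For batched input, only an all-zero attention_mask is acceptable.
    if batch_size > 1 and attention_mask is not None:
        if not torch.all(attention_mask.eq(0)):
            return False
    return True

Because the moved batched-input check and the new PVC check both reuse the device_name string fetched earlier via torch.xpu.get_device_name(query_states.device.index), get_device_name is called only once per invocation, which is the "avoid calling get_device_name twice" fix mentioned in the commit message.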