LLM: Disable esimd sdp for PVC GPU when batch size>1 (#10579)
* llm: disable esimd sdp for PVC when batch size > 1.
* fix logic.
* fix: avoid calling get_device_name twice.
parent e6c5a6a5e6
commit b44f7adbad
1 changed file with 9 additions and 5 deletions
@@ -317,11 +317,6 @@ def use_esimd_sdp(q_len, k_len, head_dim, query_states, attention_mask=None):
     elif query_states.dtype != torch.float16:
         # esimd_sdp only has optimization for FP16 now
         return False
-    elif query_states.shape[0] > 1 and attention_mask is not None:
-        # for batched input, can't accept attention_mask
-        # TODO: this check needs some time
-        if not torch.all(attention_mask.eq(0)):
-            return False
 
     device_name = torch.xpu.get_device_name(query_states.device.index)
     if device_name.startswith("Intel(R) Arc(TM) A") or \
@@ -333,6 +328,15 @@ def use_esimd_sdp(q_len, k_len, head_dim, query_states, attention_mask=None):
     else:
         return False
 
+    if query_states.shape[0] > 1 and device_name.startswith("Intel(R) Data Center GPU Max"):
+        # esimd_sdp not support PVC GPU when batch size > 1 for now
+        return False
+    if query_states.shape[0] > 1 and attention_mask is not None:
+        # for batched input, can't accept attention_mask
+        # TODO: this check needs some time
+        if not torch.all(attention_mask.eq(0)):
+            return False
+
     return True
 
 
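For reference, a minimal standalone sketch of the intent behind the new guard (illustrative only; pvc_allows_esimd_sdp is a hypothetical helper that mirrors the check added to use_esimd_sdp, not a function in the repo): ESIMD SDP is skipped on PVC, i.e. any device whose name starts with "Intel(R) Data Center GPU Max", whenever the batch dimension is greater than 1.

# Illustrative sketch, not part of the commit; device names below are examples.
def pvc_allows_esimd_sdp(batch_size: int, device_name: str) -> bool:
    # esimd_sdp does not support PVC GPUs when batch size > 1 for now
    if batch_size > 1 and device_name.startswith("Intel(R) Data Center GPU Max"):
        return False
    return True

assert pvc_allows_esimd_sdp(1, "Intel(R) Data Center GPU Max 1100")      # single batch on PVC: allowed
assert not pvc_allows_esimd_sdp(2, "Intel(R) Data Center GPU Max 1100")  # batched input on PVC: disabled
assert pvc_allows_esimd_sdp(2, "Intel(R) Arc(TM) A770 Graphics")         # batched input on Arc: unaffected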