LLM: Disable esimd sdp for PVC GPU when batch size>1 (#10579)
* llm: disable esimd sdp for pvc bz>1. * fix logic. * fix: avoid call get device name twice.
This commit is contained in:
parent
e6c5a6a5e6
commit
b44f7adbad
1 changed files with 9 additions and 5 deletions
|
|
@ -317,11 +317,6 @@ def use_esimd_sdp(q_len, k_len, head_dim, query_states, attention_mask=None):
|
||||||
elif query_states.dtype != torch.float16:
|
elif query_states.dtype != torch.float16:
|
||||||
# esimd_sdp only has optimization for FP16 now
|
# esimd_sdp only has optimization for FP16 now
|
||||||
return False
|
return False
|
||||||
elif query_states.shape[0] > 1 and attention_mask is not None:
|
|
||||||
# for batched input, can't accept attention_mask
|
|
||||||
# TODO: this check needs some time
|
|
||||||
if not torch.all(attention_mask.eq(0)):
|
|
||||||
return False
|
|
||||||
|
|
||||||
device_name = torch.xpu.get_device_name(query_states.device.index)
|
device_name = torch.xpu.get_device_name(query_states.device.index)
|
||||||
if device_name.startswith("Intel(R) Arc(TM) A") or \
|
if device_name.startswith("Intel(R) Arc(TM) A") or \
|
||||||
|
|
@ -333,6 +328,15 @@ def use_esimd_sdp(q_len, k_len, head_dim, query_states, attention_mask=None):
|
||||||
else:
|
else:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
if query_states.shape[0] > 1 and device_name.startswith("Intel(R) Data Center GPU Max"):
|
||||||
|
# esimd_sdp not support PVC GPU when batch size > 1 for now
|
||||||
|
return False
|
||||||
|
if query_states.shape[0] > 1 and attention_mask is not None:
|
||||||
|
# for batched input, can't accept attention_mask
|
||||||
|
# TODO: this check needs some time
|
||||||
|
if not torch.all(attention_mask.eq(0)):
|
||||||
|
return False
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue