LLM: Disable esimd sdp for PVC GPU when batch size>1 (#10579)

* llm: disable esimd sdp for pvc bz>1.

* fix logic.

* fix: avoid call get device name twice.
This commit is contained in:
Cengguang Zhang 2024-03-28 22:55:48 +08:00 committed by GitHub
parent e6c5a6a5e6
commit b44f7adbad
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -317,11 +317,6 @@ def use_esimd_sdp(q_len, k_len, head_dim, query_states, attention_mask=None):
elif query_states.dtype != torch.float16:
# esimd_sdp only has optimization for FP16 now
return False
elif query_states.shape[0] > 1 and attention_mask is not None:
# for batched input, can't accept attention_mask
# TODO: this check needs some time
if not torch.all(attention_mask.eq(0)):
return False
device_name = torch.xpu.get_device_name(query_states.device.index)
if device_name.startswith("Intel(R) Arc(TM) A") or \
@ -333,6 +328,15 @@ def use_esimd_sdp(q_len, k_len, head_dim, query_states, attention_mask=None):
else:
return False
if query_states.shape[0] > 1 and device_name.startswith("Intel(R) Data Center GPU Max"):
# esimd_sdp not support PVC GPU when batch size > 1 for now
return False
if query_states.shape[0] > 1 and attention_mask is not None:
# for batched input, can't accept attention_mask
# TODO: this check needs some time
if not torch.all(attention_mask.eq(0)):
return False
return True