From b44f7adbad3686bf295422fffdf38a3f5163a75c Mon Sep 17 00:00:00 2001
From: Cengguang Zhang
Date: Thu, 28 Mar 2024 22:55:48 +0800
Subject: [PATCH] LLM: Disable esimd sdp for PVC GPU when batch size>1 (#10579)

* llm: disable esimd sdp for pvc bz>1.

* fix logic.

* fix: avoid calling get device name twice.
---
 .../llm/src/ipex_llm/transformers/models/utils.py | 14 +++++++++-----
 1 file changed, 9 insertions(+), 5 deletions(-)

diff --git a/python/llm/src/ipex_llm/transformers/models/utils.py b/python/llm/src/ipex_llm/transformers/models/utils.py
index 53e3b003..ee91dcd7 100644
--- a/python/llm/src/ipex_llm/transformers/models/utils.py
+++ b/python/llm/src/ipex_llm/transformers/models/utils.py
@@ -317,11 +317,6 @@ def use_esimd_sdp(q_len, k_len, head_dim, query_states, attention_mask=None):
     elif query_states.dtype != torch.float16:
         # esimd_sdp only has optimization for FP16 now
         return False
-    elif query_states.shape[0] > 1 and attention_mask is not None:
-        # for batched input, can't accept attention_mask
-        # TODO: this check needs some time
-        if not torch.all(attention_mask.eq(0)):
-            return False
 
     device_name = torch.xpu.get_device_name(query_states.device.index)
     if device_name.startswith("Intel(R) Arc(TM) A") or \
@@ -333,6 +328,15 @@ def use_esimd_sdp(q_len, k_len, head_dim, query_states, attention_mask=None):
     else:
         return False
 
+    if query_states.shape[0] > 1 and device_name.startswith("Intel(R) Data Center GPU Max"):
+        # esimd_sdp does not support PVC GPU when batch size > 1 for now
+        return False
+    if query_states.shape[0] > 1 and attention_mask is not None:
+        # for batched input, can't accept attention_mask
+        # TODO: this check needs some time
+        if not torch.all(attention_mask.eq(0)):
+            return False
+
     return True
 
 
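For context, use_esimd_sdp() is a pure guard: callers branch on its result to
pick between the fused ESIMD kernel and an eager attention path. Below is a
minimal sketch of such a caller. The names attention_forward and the argument
order of linear_fp16_esimd.sdp_forward are assumptions for illustration (the
module itself is the fast path this guard protects in ipex-llm); the fallback
branch is ordinary scaled-dot-product attention.

    import math
    import torch
    from ipex_llm.transformers.models.utils import use_esimd_sdp

    def attention_forward(query_states, key_states, value_states,
                          attention_mask=None):
        # All tensors: [batch, num_heads, seq_len, head_dim], FP16 on XPU.
        q_len, k_len = query_states.size(2), key_states.size(2)
        head_dim = query_states.size(3)
        if use_esimd_sdp(q_len, k_len, head_dim, query_states, attention_mask):
            # Fast path: fused ESIMD SDP kernel (signature assumed here).
            import linear_fp16_esimd
            attn_output = linear_fp16_esimd.sdp_forward(
                query_states, key_states, value_states)
        else:
            # Fallback: eager scaled-dot-product attention.
            attn_weights = torch.matmul(query_states,
                                        key_states.transpose(2, 3))
            attn_weights = attn_weights / math.sqrt(head_dim)
            if attention_mask is not None:
                attn_weights = attn_weights + attention_mask
            attn_weights = torch.softmax(attn_weights, dim=-1,
                                         dtype=torch.float32)
            attn_output = torch.matmul(attn_weights.to(value_states.dtype),
                                       value_states)
        return attn_output

With this patch applied, a batched call (query_states.shape[0] > 1) on an
Intel Data Center GPU Max (PVC) device makes the guard return False, so the
caller takes the eager fallback; and because both batched checks now run after
the device lookup, torch.xpu.get_device_name() is called only once per call,
matching the "avoid calling get device name twice" fix in the commit message.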