From df2b84f7de527d3ec6f9381532338d39614a5b6e Mon Sep 17 00:00:00 2001
From: Zhao Changmin
Date: Tue, 12 Mar 2024 16:46:04 +0800
Subject: [PATCH] Enable kv cache on arc batch (#10308)

---
 python/llm/src/bigdl/llm/transformers/models/utils.py | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/python/llm/src/bigdl/llm/transformers/models/utils.py b/python/llm/src/bigdl/llm/transformers/models/utils.py
index aa791f26..88e5792f 100644
--- a/python/llm/src/bigdl/llm/transformers/models/utils.py
+++ b/python/llm/src/bigdl/llm/transformers/models/utils.py
@@ -73,11 +73,16 @@ def use_quantize_kv_cache(linear: torch.nn.Module, x: torch.Tensor) -> bool:
     if os.environ.get("BIGDL_QUANTIZE_KV_CACHE", None) is not None:
         return os.environ["BIGDL_QUANTIZE_KV_CACHE"] == "1"
     else:
-        return x.device.type == 'xpu' and get_xpu_device_type(x) == "mtl" \
+        return x.device.type == 'xpu' and kv_cache_device_check(x) \
             and hasattr(linear, "qtype") and \
             linear.qtype != ggml_tensor_qtype["fp16"] and linear.qtype != ggml_tensor_qtype["bf16"]
 
 
+def kv_cache_device_check(x: torch.Tensor) -> bool:
+    return get_xpu_device_type(x) == "mtl" or \
+        (get_xpu_device_type(x) == "arc" and 1 < x.size(0) and x.size(0) < 8)
+
+
 def init_fp8_kv_cache(batch_size, num_heads, current_length, head_dim, device):
     max_length = current_length + FP8_KV_ALLOC_LENGTH
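
Note on the change: before this patch, `use_quantize_kv_cache` enabled the quantized (FP8) KV cache on XPU only for MTL devices; the patch factors the device test into a new helper, `kv_cache_device_check`, which also accepts Arc devices when the batch size (`x.size(0)`) is strictly between 1 and 8. Below is a minimal standalone sketch of that predicate for illustration only: it takes `device_type` and `batch_size` directly, standing in for the real helper, which derives both from a tensor via `get_xpu_device_type(x)` and `x.size(0)`.

```python
def kv_cache_device_check(device_type: str, batch_size: int) -> bool:
    """Sketch of the patched gating logic (parameters are stand-ins,
    not the real signature, which takes a torch.Tensor)."""
    # MTL: quantized KV cache is enabled regardless of batch size.
    # Arc: enabled only for batch sizes 2..7 (1 < batch_size < 8).
    return device_type == "mtl" or \
        (device_type == "arc" and 1 < batch_size < 8)


if __name__ == "__main__":
    assert kv_cache_device_check("mtl", 1)       # MTL: always on
    assert not kv_cache_device_check("arc", 1)   # Arc, batch 1: off
    assert kv_cache_device_check("arc", 4)       # Arc, batch 2..7: on
    assert not kv_cache_device_check("arc", 8)   # Arc, batch >= 8: off
```

Note that the predicate is still combined with the other conditions in `use_quantize_kv_cache` (XPU device, non-fp16/bf16 qtype), and the `BIGDL_QUANTIZE_KV_CACHE` environment variable continues to override it entirely.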