From 36496d60ac30a0d12d4719578c1a29f126c39d88 Mon Sep 17 00:00:00 2001
From: Yishuo Wang
Date: Tue, 9 Jan 2024 13:24:02 +0800
Subject: [PATCH] only use quantize kv cache on MTL (#9862)

---
 python/llm/src/bigdl/llm/transformers/models/utils.py | 5 +++--
 python/llm/src/bigdl/llm/transformers/utils.py        | 2 ++
 2 files changed, 5 insertions(+), 2 deletions(-)

diff --git a/python/llm/src/bigdl/llm/transformers/models/utils.py b/python/llm/src/bigdl/llm/transformers/models/utils.py
index fbd802ca..41994aec 100644
--- a/python/llm/src/bigdl/llm/transformers/models/utils.py
+++ b/python/llm/src/bigdl/llm/transformers/models/utils.py
@@ -18,7 +18,7 @@ import os
 import torch
 from bigdl.llm.utils.common import invalidInputError
 from bigdl.llm.ggml.quantize import ggml_tensor_qtype
-from bigdl.llm.transformers.utils import get_ipex_version
+from bigdl.llm.transformers.utils import get_ipex_version, get_xpu_device_type
 
 
 def init_kv_cache(batch_size, num_heads, head_dim, current_length, max_length, dtype, device):
@@ -63,7 +63,8 @@ def quantize_kv_cache(linear: torch.nn.Module, x: torch.Tensor) -> bool:
     if os.environ.get("BIGDL_QUANTIZE_KV_CACHE", None) is not None:
         return os.environ["BIGDL_QUANTIZE_KV_CACHE"] == "1"
     else:
-        return x.device.type == 'xpu' and hasattr(linear, "qtype") and \
+        return x.device.type == 'xpu' and get_xpu_device_type(x) == "mtl" \
+            and hasattr(linear, "qtype") and \
             linear.qtype != ggml_tensor_qtype["fp16"] and linear.qtype != ggml_tensor_qtype["bf16"]
 
 
diff --git a/python/llm/src/bigdl/llm/transformers/utils.py b/python/llm/src/bigdl/llm/transformers/utils.py
index 4aabb485..3361b9a7 100644
--- a/python/llm/src/bigdl/llm/transformers/utils.py
+++ b/python/llm/src/bigdl/llm/transformers/utils.py
@@ -169,6 +169,8 @@ def get_xpu_device_type(x):
     name = torch.xpu.get_device_name(x.device.index)
     if name.startswith("Intel(R) Arc(TM) A"):
         return "arc"
+    elif name.startswith("Intel(R) Arc(TM)"):
+        return "mtl"
     elif name.startswith("Intel(R) Data Center GPU Flex"):
         return "flex"
     elif name.startswith("Intel(R) Data Center GPU Max"):
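
A minimal, self-contained sketch of the behavior change follows. It is illustrative only, not the patched code itself: the helper names (xpu_device_type, should_quantize_kv_cache) and the MTL iGPU device string "Intel(R) Arc(TM) Graphics" are assumptions for demonstration, the device name and qtype are passed in directly so the logic runs without torch or an XPU, and in BigDL the name actually comes from torch.xpu.get_device_name() inside get_xpu_device_type(x) while the qtype lives on the Linear module.

import os

def xpu_device_type(device_name):
    # Branch order matters: discrete Arc A-series matches first, so only the
    # remaining "Intel(R) Arc(TM)" names (the MTL iGPU) map to "mtl".
    if device_name.startswith("Intel(R) Arc(TM) A"):
        return "arc"
    elif device_name.startswith("Intel(R) Arc(TM)"):
        return "mtl"
    elif device_name.startswith("Intel(R) Data Center GPU Flex"):
        return "flex"
    elif device_name.startswith("Intel(R) Data Center GPU Max"):
        return "max"
    return "others"  # assumption: fallback label for unmatched devices

def should_quantize_kv_cache(device_type, device_name, qtype):
    # The BIGDL_QUANTIZE_KV_CACHE env var still overrides the heuristic
    # in both directions, exactly as before the patch.
    override = os.environ.get("BIGDL_QUANTIZE_KV_CACHE", None)
    if override is not None:
        return override == "1"
    # Otherwise the quantized KV cache is now limited to MTL iGPUs and, as
    # before, skipped when the linear layer runs in fp16 or bf16.
    return (device_type == "xpu"
            and xpu_device_type(device_name) == "mtl"
            and qtype not in ("fp16", "bf16"))

# Before this patch, every XPU device with a non-fp16/bf16 qtype quantized
# the KV cache; after it, a discrete Arc card uses the regular cache:
print(should_quantize_kv_cache("xpu", "Intel(R) Arc(TM) A770 Graphics", "sym_int4"))  # False
print(should_quantize_kv_cache("xpu", "Intel(R) Arc(TM) Graphics", "sym_int4"))       # True

The key design point the sketch highlights is the prefix-match ordering in get_xpu_device_type: the new "Intel(R) Arc(TM)" branch is inserted after the "Intel(R) Arc(TM) A" check, so discrete Arc A-series cards are not misclassified as MTL.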