only use quantize kv cache on MTL (#9862)

This commit is contained in:
Yishuo Wang 2024-01-09 13:24:02 +08:00 committed by GitHub
parent 146076bdb5
commit 36496d60ac
2 changed files with 5 additions and 2 deletions

View file

@ -18,7 +18,7 @@ import os
import torch
from bigdl.llm.utils.common import invalidInputError
from bigdl.llm.ggml.quantize import ggml_tensor_qtype
from bigdl.llm.transformers.utils import get_ipex_version
from bigdl.llm.transformers.utils import get_ipex_version, get_xpu_device_type
def init_kv_cache(batch_size, num_heads, head_dim, current_length, max_length, dtype, device):
@ -63,7 +63,8 @@ def quantize_kv_cache(linear: torch.nn.Module, x: torch.Tensor) -> bool:
if os.environ.get("BIGDL_QUANTIZE_KV_CACHE", None) is not None:
return os.environ["BIGDL_QUANTIZE_KV_CACHE"] == "1"
else:
return x.device.type == 'xpu' and hasattr(linear, "qtype") and \
return x.device.type == 'xpu' and get_xpu_device_type(x) == "mtl" \
and hasattr(linear, "qtype") and \
linear.qtype != ggml_tensor_qtype["fp16"] and linear.qtype != ggml_tensor_qtype["bf16"]

View file

@ -169,6 +169,8 @@ def get_xpu_device_type(x):
name = torch.xpu.get_device_name(x.device.index)
if name.startswith("Intel(R) Arc(TM) A"):
return "arc"
elif name.startswith("Intel(R) Arc(TM)"):
return "mtl"
elif name.startswith("Intel(R) Data Center GPU Flex"):
return "flex"
elif name.startswith("Intel(R) Data Center GPU Max"):