only use quantize kv cache on MTL (#9862)
commit 36496d60ac
parent 146076bdb5
2 changed files with 5 additions and 2 deletions
@@ -18,7 +18,7 @@ import os
 import torch
 from bigdl.llm.utils.common import invalidInputError
 from bigdl.llm.ggml.quantize import ggml_tensor_qtype
-from bigdl.llm.transformers.utils import get_ipex_version
+from bigdl.llm.transformers.utils import get_ipex_version, get_xpu_device_type


 def init_kv_cache(batch_size, num_heads, head_dim, current_length, max_length, dtype, device):
@@ -63,7 +63,8 @@ def quantize_kv_cache(linear: torch.nn.Module, x: torch.Tensor) -> bool:
     if os.environ.get("BIGDL_QUANTIZE_KV_CACHE", None) is not None:
         return os.environ["BIGDL_QUANTIZE_KV_CACHE"] == "1"
     else:
-        return x.device.type == 'xpu' and hasattr(linear, "qtype") and \
+        return x.device.type == 'xpu' and get_xpu_device_type(x) == "mtl" \
+            and hasattr(linear, "qtype") and \
             linear.qtype != ggml_tensor_qtype["fp16"] and linear.qtype != ggml_tensor_qtype["bf16"]

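After this hunk, the default path in `quantize_kv_cache` returns `True` only on MTL XPUs, while an explicit `BIGDL_QUANTIZE_KV_CACHE` setting still overrides the default in either direction. Below is a minimal standalone sketch of that decision, not the library code: `should_quantize_kv_cache`, `_FakeLinear`, the `device_type` argument (standing in for `get_xpu_device_type(x)`), and the `"fp16"`/`"bf16"` strings (standing in for the `ggml_tensor_qtype` constants) are all illustrative assumptions.

```python
import os


def should_quantize_kv_cache(linear, device_is_xpu: bool, device_type: str) -> bool:
    """Sketch of the decision quantize_kv_cache makes after this commit."""
    # An explicit BIGDL_QUANTIZE_KV_CACHE setting wins in either direction.
    if os.environ.get("BIGDL_QUANTIZE_KV_CACHE", None) is not None:
        return os.environ["BIGDL_QUANTIZE_KV_CACHE"] == "1"
    # Otherwise the default now requires an MTL XPU and a low-bit linear layer
    # (fp16/bf16 layers keep an unquantized KV cache).
    return (device_is_xpu and device_type == "mtl"
            and hasattr(linear, "qtype")
            and linear.qtype not in ("fp16", "bf16"))


class _FakeLinear:  # hypothetical stand-in for a low-bit Linear with a qtype attribute
    qtype = "q4_0"


os.environ.pop("BIGDL_QUANTIZE_KV_CACHE", None)
assert should_quantize_kv_cache(_FakeLinear(), True, "mtl") is True   # MTL: quantized by default
assert should_quantize_kv_cache(_FakeLinear(), True, "arc") is False  # Arc: no longer quantized by default
os.environ["BIGDL_QUANTIZE_KV_CACHE"] = "1"                           # explicit opt-in still works anywhere
assert should_quantize_kv_cache(_FakeLinear(), True, "arc") is True
```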
@@ -169,6 +169,8 @@ def get_xpu_device_type(x):
     name = torch.xpu.get_device_name(x.device.index)
     if name.startswith("Intel(R) Arc(TM) A"):
         return "arc"
+    elif name.startswith("Intel(R) Arc(TM)"):
+        return "mtl"
     elif name.startswith("Intel(R) Data Center GPU Flex"):
         return "flex"
     elif name.startswith("Intel(R) Data Center GPU Max"):
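The new branch classifies any Arc-branded device name that does not carry the A-series suffix as "mtl", so the order of the prefix checks matters: the more specific `"Intel(R) Arc(TM) A"` test must run before the generic `"Intel(R) Arc(TM)"` test. The sketch below illustrates that ordering; `classify_xpu_name` and the example device-name strings are illustrative assumptions, not taken from the diff, and the real `get_xpu_device_type` has further branches that the hunk above truncates.

```python
def classify_xpu_name(name: str) -> str:
    """Sketch of the prefix ordering added in this hunk."""
    if name.startswith("Intel(R) Arc(TM) A"):           # discrete Arc A-series
        return "arc"
    elif name.startswith("Intel(R) Arc(TM)"):            # remaining Arc-branded devices -> treated as MTL
        return "mtl"
    elif name.startswith("Intel(R) Data Center GPU Flex"):
        return "flex"
    # The real function also handles Data Center GPU Max and a fallback case,
    # which are cut off in the hunk above.
    return "others"


# The specific A-series prefix is tested before the generic Arc prefix; otherwise
# every Arc card would be classified as "mtl".
assert classify_xpu_name("Intel(R) Arc(TM) A770 Graphics") == "arc"
assert classify_xpu_name("Intel(R) Arc(TM) Graphics") == "mtl"
```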