Enable kv cache on arc batch (#10308)

Zhao Changmin 2024-03-12 16:46:04 +08:00 committed by GitHub
parent 5809a3f5fe
commit df2b84f7de


@@ -73,11 +73,16 @@ def use_quantize_kv_cache(linear: torch.nn.Module, x: torch.Tensor) -> bool:
     if os.environ.get("BIGDL_QUANTIZE_KV_CACHE", None) is not None:
         return os.environ["BIGDL_QUANTIZE_KV_CACHE"] == "1"
     else:
-        return x.device.type == 'xpu' and get_xpu_device_type(x) == "mtl" \
+        return x.device.type == 'xpu' and kv_cache_device_check(x) \
             and hasattr(linear, "qtype") and \
             linear.qtype != ggml_tensor_qtype["fp16"] and linear.qtype != ggml_tensor_qtype["bf16"]
 
 
+def kv_cache_device_check(x: torch.Tensor) -> bool:
+    return get_xpu_device_type(x) == "mtl" or \
+        (get_xpu_device_type(x) == "arc" and 1 < x.size(0) and x.size(0) < 8)
+
+
 def init_fp8_kv_cache(batch_size, num_heads, current_length, head_dim, device):
     max_length = current_length + FP8_KV_ALLOC_LENGTH
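
For context, a minimal self-contained sketch of the gating behavior this commit introduces: the quantized (FP8) KV cache remains enabled on MTL regardless of batch size, while on Arc it is only enabled for batch sizes 2 through 7. The helper below is hypothetical, it takes the device type and batch size directly instead of deriving them from a tensor the way the library's kv_cache_device_check does.

# Hypothetical standalone sketch of the batch-size gate; the real
# kv_cache_device_check receives a torch.Tensor and reads the device type
# via get_xpu_device_type(x) and the batch size via x.size(0).
def kv_cache_gate_sketch(device_type: str, batch_size: int) -> bool:
    # Quantized KV cache: always on for MTL; on Arc only for batch sizes 2-7.
    return device_type == "mtl" or (device_type == "arc" and 1 < batch_size < 8)

# Arc: enabled for batch sizes 2-7, disabled for 1 and for 8 or more.
assert kv_cache_gate_sketch("arc", 4) is True
assert kv_cache_gate_sketch("arc", 1) is False
assert kv_cache_gate_sketch("arc", 8) is False
# MTL: enabled independent of batch size.
assert kv_cache_gate_sketch("mtl", 1) is True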