diff --git a/python/llm/src/ipex_llm/transformers/models/utils.py b/python/llm/src/ipex_llm/transformers/models/utils.py
index ee91dcd7..222634ef 100644
--- a/python/llm/src/ipex_llm/transformers/models/utils.py
+++ b/python/llm/src/ipex_llm/transformers/models/utils.py
@@ -73,7 +73,7 @@ def append_kv_cache(cache_k, cache_v, key_states, value_states):
 
 def use_quantize_kv_cache(linear: torch.nn.Module, x: torch.Tensor) -> bool:
     if os.environ.get("BIGDL_QUANTIZE_KV_CACHE", None) is not None:
-        return os.environ["BIGDL_QUANTIZE_KV_CACHE"] == "1"
+        return int(os.environ["BIGDL_QUANTIZE_KV_CACHE"]) == 1
     else:
         return x.device.type == 'xpu' and kv_cache_device_check(x) \
             and hasattr(linear, "qtype") and \
@@ -82,7 +82,8 @@ def use_quantize_kv_cache(linear: torch.nn.Module, x: torch.Tensor) -> bool:
 
 def kv_cache_device_check(x: torch.Tensor) -> bool:
     return get_xpu_device_type(x) == "mtl" or \
-        (get_xpu_device_type(x) == "arc" and 1 < x.size(0) and x.size(0) < 8)
+        ((get_xpu_device_type(x) == "arc" or get_xpu_device_type(x) == "flex") and \
+            1 < x.size(0) and x.size(0) <= 8)
 
 
 def init_fp8_kv_cache(batch_size, num_heads, current_length, head_dim, device, new_layout=False):
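
Note (not part of the patch): a minimal standalone sketch of the two behavior changes above, with the torch/XPU specifics stripped out. The helper names quantize_kv_env_override and device_allows_quantized_kv are hypothetical and exist only for illustration.

import os

def quantize_kv_env_override() -> bool:
    # New check: int(...) == 1 also accepts values such as "01" or " 1";
    # the old string comparison only matched the exact literal "1".
    return int(os.environ["BIGDL_QUANTIZE_KV_CACHE"]) == 1

def device_allows_quantized_kv(device_type: str, batch_size: int) -> bool:
    # MTL always qualifies; Arc and (newly) Flex qualify for
    # batch sizes 2 through 8, with 8 now included.
    return device_type == "mtl" or \
        (device_type in ("arc", "flex") and 1 < batch_size <= 8)

os.environ["BIGDL_QUANTIZE_KV_CACHE"] = "01"
print(quantize_kv_env_override())             # True with the new check; the old "01" == "1" was False
print(device_allows_quantized_kv("flex", 8))  # True; neither "flex" nor batch size 8 qualified before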