diff --git a/python/llm/src/bigdl/llm/transformers/models/utils.py b/python/llm/src/bigdl/llm/transformers/models/utils.py
index 67f4e6e1..5efdf0d9 100644
--- a/python/llm/src/bigdl/llm/transformers/models/utils.py
+++ b/python/llm/src/bigdl/llm/transformers/models/utils.py
@@ -23,7 +23,8 @@ from bigdl.llm.transformers.utils import get_ipex_version, get_xpu_device_type
 
 SYM_INT4 = ggml_tensor_qtype["sym_int4"]
 SYM_INT8 = ggml_tensor_qtype["sym_int8"]
-FP8 = ggml_tensor_qtype["fp8"]
+FP8E4 = ggml_tensor_qtype["fp8_e4m3"]
+FP8E5 = ggml_tensor_qtype["fp8_e5m2"]
 
 
 def init_kv_cache(batch_size, num_heads, head_dim, current_length, max_length, dtype, device):
@@ -297,7 +298,7 @@ def use_xmx(x: torch.Tensor, qtype: int):
     return (
         os.environ.get("BIGDL_LLM_XMX_DISABLED", "0") != "1"
         and device in ["arc", "flex", "pvc"]
-        and qtype in [SYM_INT4, SYM_INT8, FP8]
+        and qtype in [SYM_INT4, SYM_INT8, FP8E4, FP8E5]
         and (
             (device == "pvc" and 1 < x.size(0) <= 16)
             or
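
For reference (not part of the patch itself): the two new qtype names correspond to the two standard FP8 encodings, e4m3 (more mantissa bits, finer precision) and e5m2 (more exponent bits, wider dynamic range). A minimal sketch of that tradeoff, assuming PyTorch >= 2.1, which exposes both encodings as native dtypes:

```python
import torch

# Values chosen to sit inside both formats' ranges; 448 is the largest
# finite e4m3 value, while e5m2 extends to 57344 at coarser precision.
x = torch.tensor([0.1, 1.0, 100.0, 448.0])

# e4m3: 4 exponent bits, 3 mantissa bits -> finer-grained values
print(x.to(torch.float8_e4m3fn).to(torch.float32))

# e5m2: 5 exponent bits, 2 mantissa bits -> wider range, coarser steps
print(x.to(torch.float8_e5m2).to(torch.float32))
```

The FP8E4/FP8E5 constants above only route the qtype check in use_xmx; how each encoding is quantized and dequantized is handled elsewhere by the ggml backend.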