From a3322e2a6c49f49ee2c66edafdeaf0e53c25dc9f Mon Sep 17 00:00:00 2001
From: Yina Chen <33650826+cyita@users.noreply.github.com>
Date: Fri, 26 Jan 2024 18:29:46 +0800
Subject: [PATCH] add fp8 e5 to use_xmx (#10015)

---
 python/llm/src/bigdl/llm/transformers/models/utils.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/python/llm/src/bigdl/llm/transformers/models/utils.py b/python/llm/src/bigdl/llm/transformers/models/utils.py
index 67f4e6e1..5efdf0d9 100644
--- a/python/llm/src/bigdl/llm/transformers/models/utils.py
+++ b/python/llm/src/bigdl/llm/transformers/models/utils.py
@@ -23,7 +23,8 @@ from bigdl.llm.transformers.utils import get_ipex_version, get_xpu_device_type
 
 SYM_INT4 = ggml_tensor_qtype["sym_int4"]
 SYM_INT8 = ggml_tensor_qtype["sym_int8"]
-FP8 = ggml_tensor_qtype["fp8"]
+FP8E4 = ggml_tensor_qtype["fp8_e4m3"]
+FP8E5 = ggml_tensor_qtype["fp8_e5m2"]
 
 
 def init_kv_cache(batch_size, num_heads, head_dim, current_length, max_length, dtype, device):
@@ -297,7 +298,7 @@ def use_xmx(x: torch.Tensor, qtype: int):
     return (
         os.environ.get("BIGDL_LLM_XMX_DISABLED", "0") != "1"
         and device in ["arc", "flex", "pvc"]
-        and qtype in [SYM_INT4, SYM_INT8, FP8]
+        and qtype in [SYM_INT4, SYM_INT8, FP8E4, FP8E5]
         and (
             (device == "pvc" and 1 < x.size(0) <= 16)
             or
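
Note: for readers skimming the change, the snippet below is a minimal, self-contained sketch of the qtype gate that use_xmx applies after this patch. It is illustrative only: the stand-in ggml_tensor_qtype values and the qtype_supports_xmx helper are hypothetical (the real mapping is the ggml_tensor_qtype dict imported by utils.py), and the device and batch-size checks visible in the hunk above are omitted here.

import os

# Hypothetical stand-in for the real ggml_tensor_qtype mapping used in utils.py;
# only the keys come from the patch, the integer values are made up.
ggml_tensor_qtype = {
    "sym_int4": 2,
    "sym_int8": 3,
    "fp8_e4m3": 15,
    "fp8_e5m2": 19,
}

SYM_INT4 = ggml_tensor_qtype["sym_int4"]
SYM_INT8 = ggml_tensor_qtype["sym_int8"]
FP8E4 = ggml_tensor_qtype["fp8_e4m3"]
FP8E5 = ggml_tensor_qtype["fp8_e5m2"]


def qtype_supports_xmx(qtype: int) -> bool:
    # Mirrors the environment-variable and qtype checks from use_xmx();
    # fp8_e5m2 (FP8E5) is newly accepted by this patch.
    return (
        os.environ.get("BIGDL_LLM_XMX_DISABLED", "0") != "1"
        and qtype in [SYM_INT4, SYM_INT8, FP8E4, FP8E5]
    )


print(qtype_supports_xmx(FP8E5))  # True; fp8_e5m2 was not in the eligible list before this patch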