From 7429ea0606e617661414b00aa0aea063ebe13949 Mon Sep 17 00:00:00 2001 From: Yishuo Wang Date: Tue, 29 Aug 2023 17:27:18 +0800 Subject: [PATCH] [LLM] support transformer int4 + amx int4 (#8838) --- python/llm/src/bigdl/llm/transformers/linear_quant.py | 6 ++++-- python/llm/src/bigdl/llm/utils/isa_checker.py | 6 ++++++ 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/python/llm/src/bigdl/llm/transformers/linear_quant.py b/python/llm/src/bigdl/llm/transformers/linear_quant.py index f7b4bab1..09e7b666 100644 --- a/python/llm/src/bigdl/llm/transformers/linear_quant.py +++ b/python/llm/src/bigdl/llm/transformers/linear_quant.py @@ -51,10 +51,11 @@ from torch import Tensor, device, dtype, nn T = TypeVar("T", bound="torch.nn.Module") import bigdl.llm.ggml.model.llama.llama_cpp as ggml -from bigdl.llm.utils.isa_checker import is_server +from bigdl.llm.utils.isa_checker import is_server, is_spr import ctypes from bigdl.llm.ggml.quantize import ggml_tensor_qtype IS_SERVER = is_server() +IS_SPR = is_spr() TORCH_LINEAR_THRESHOLD = 96 SYM_INT4 = ggml_tensor_qtype["sym_int4"] @@ -256,7 +257,8 @@ class LinearQuant(nn.Linear): else: # CPU logic # todo may need to set a different number on different platforms - if IS_SERVER and self.qtype == SYM_INT4 and x_2d.shape[0] >= TORCH_LINEAR_THRESHOLD: + if IS_SERVER and (not IS_SPR) and \ + self.qtype == SYM_INT4 and x_2d.shape[0] >= TORCH_LINEAR_THRESHOLD: x0_fp32 = ggml_int4_convert_fp32(x0, self.weight_shape, self.weight_length) result = F.linear(x, x0_fp32, self.bias) else: diff --git a/python/llm/src/bigdl/llm/utils/isa_checker.py b/python/llm/src/bigdl/llm/utils/isa_checker.py index 0118506b..0e913dc2 100644 --- a/python/llm/src/bigdl/llm/utils/isa_checker.py +++ b/python/llm/src/bigdl/llm/utils/isa_checker.py @@ -79,3 +79,9 @@ def check_avx512_vnni(): def is_server(): return check_avx512_vnni() + + +# todo: use cpuid to check SPR +# note: now only SPR supports both avxvnni and avx512vnni +def is_spr(): + return check_avx_vnni() and check_avx512_vnni()