diff --git a/python/llm/src/bigdl/llm/transformers/low_bit_linear.py b/python/llm/src/bigdl/llm/transformers/low_bit_linear.py index 8337d345..1aaf9e61 100644 --- a/python/llm/src/bigdl/llm/transformers/low_bit_linear.py +++ b/python/llm/src/bigdl/llm/transformers/low_bit_linear.py @@ -55,11 +55,9 @@ from bigdl.llm.transformers.utils import get_autocast_dtype T = TypeVar("T", bound="torch.nn.Module") import bigdl.llm.ggml.model.llama.llama_cpp as ggml -from bigdl.llm.utils.isa_checker import is_server, is_spr import ctypes from bigdl.llm.ggml.quantize import ggml_tensor_qtype -IS_SERVER = is_server() -IS_SPR = is_spr() + TORCH_LINEAR_THRESHOLD = int(os.getenv("BIGDL_LLM_LINEAR_THRESHOLD", "512")) SYM_INT4 = ggml_tensor_qtype["sym_int4"] ASYM_INT4 = ggml_tensor_qtype["asym_int4"] @@ -518,8 +516,10 @@ class LowBitLinear(nn.Linear): if self.training and x.requires_grad: result = MatMulLowBitCPU.apply(x, self.weight) else: + from bigdl.llm.utils.isa_checker import is_server, is_spr + # convert if necessary, and compute a linear result - if IS_SERVER and (not IS_SPR) and \ + if is_server() and (not is_spr()) and \ self.qtype == SYM_INT4 and x_2d.shape[0] >= TORCH_LINEAR_THRESHOLD: x0_fp32 = ggml_int4_convert_fp32(x0, self.weight_shape, self.weight_length) result = F.linear(x, x0_fp32) diff --git a/python/llm/src/bigdl/llm/utils/utils.py b/python/llm/src/bigdl/llm/utils/utils.py index 1bfde572..974bf4a1 100644 --- a/python/llm/src/bigdl/llm/utils/utils.py +++ b/python/llm/src/bigdl/llm/utils/utils.py @@ -16,25 +16,9 @@ import sys import pathlib -from bigdl.llm.utils.isa_checker import check_avx_vnni, check_avx2, check_avx512_vnni from bigdl.llm.utils.common import invalidInputError, invalidOperationError -def get_cpu_flags(): - flags = "" - if sys.platform != "win32": - if check_avx512_vnni(): - flags = "_avx512" - elif check_avx_vnni(): - flags = "_avx2" - else: - invalidOperationError(False, "Unsupported CPUFLAGS.") - else: - # flags = "_vnni" if check_avx_vnni() else "" - flags = "-api" - return flags - - def get_shared_lib_info(lib_base_name: str): # Determine the file extension based on the platform if sys.platform.startswith("linux") or sys.platform == "darwin":