fix cpuinfo error (#9793)
This commit is contained in:
parent
7ed9538b9f
commit
7d9f6c6efc
2 changed files with 4 additions and 20 deletions
|
|
@ -55,11 +55,9 @@ from bigdl.llm.transformers.utils import get_autocast_dtype
|
|||
T = TypeVar("T", bound="torch.nn.Module")
|
||||
|
||||
import bigdl.llm.ggml.model.llama.llama_cpp as ggml
|
||||
from bigdl.llm.utils.isa_checker import is_server, is_spr
|
||||
import ctypes
|
||||
from bigdl.llm.ggml.quantize import ggml_tensor_qtype
|
||||
IS_SERVER = is_server()
|
||||
IS_SPR = is_spr()
|
||||
|
||||
TORCH_LINEAR_THRESHOLD = int(os.getenv("BIGDL_LLM_LINEAR_THRESHOLD", "512"))
|
||||
SYM_INT4 = ggml_tensor_qtype["sym_int4"]
|
||||
ASYM_INT4 = ggml_tensor_qtype["asym_int4"]
|
||||
|
|
@ -518,8 +516,10 @@ class LowBitLinear(nn.Linear):
|
|||
if self.training and x.requires_grad:
|
||||
result = MatMulLowBitCPU.apply(x, self.weight)
|
||||
else:
|
||||
from bigdl.llm.utils.isa_checker import is_server, is_spr
|
||||
|
||||
# convert if necessary, and compute a linear result
|
||||
if IS_SERVER and (not IS_SPR) and \
|
||||
if is_server() and (not is_spr()) and \
|
||||
self.qtype == SYM_INT4 and x_2d.shape[0] >= TORCH_LINEAR_THRESHOLD:
|
||||
x0_fp32 = ggml_int4_convert_fp32(x0, self.weight_shape, self.weight_length)
|
||||
result = F.linear(x, x0_fp32)
|
||||
|
|
|
|||
|
|
@ -16,25 +16,9 @@
|
|||
|
||||
import sys
|
||||
import pathlib
|
||||
from bigdl.llm.utils.isa_checker import check_avx_vnni, check_avx2, check_avx512_vnni
|
||||
from bigdl.llm.utils.common import invalidInputError, invalidOperationError
|
||||
|
||||
|
||||
def get_cpu_flags():
|
||||
flags = ""
|
||||
if sys.platform != "win32":
|
||||
if check_avx512_vnni():
|
||||
flags = "_avx512"
|
||||
elif check_avx_vnni():
|
||||
flags = "_avx2"
|
||||
else:
|
||||
invalidOperationError(False, "Unsupported CPUFLAGS.")
|
||||
else:
|
||||
# flags = "_vnni" if check_avx_vnni() else ""
|
||||
flags = "-api"
|
||||
return flags
|
||||
|
||||
|
||||
def get_shared_lib_info(lib_base_name: str):
|
||||
# Determine the file extension based on the platform
|
||||
if sys.platform.startswith("linux") or sys.platform == "darwin":
|
||||
|
|
|
|||
Loading…
Reference in a new issue