[LLM] support transformer int4 + amx int4 (#8838)

This commit is contained in:
Yishuo Wang 2023-08-29 17:27:18 +08:00 committed by GitHub
parent ddff7a6f05
commit 7429ea0606
2 changed files with 10 additions and 2 deletions

View file

@@ -51,10 +51,11 @@ from torch import Tensor, device, dtype, nn
 T = TypeVar("T", bound="torch.nn.Module")
 import bigdl.llm.ggml.model.llama.llama_cpp as ggml
-from bigdl.llm.utils.isa_checker import is_server
+from bigdl.llm.utils.isa_checker import is_server, is_spr
 import ctypes
 from bigdl.llm.ggml.quantize import ggml_tensor_qtype
 IS_SERVER = is_server()
+IS_SPR = is_spr()
 TORCH_LINEAR_THRESHOLD = 96
 SYM_INT4 = ggml_tensor_qtype["sym_int4"]
@@ -256,7 +257,8 @@ class LinearQuant(nn.Linear):
         else:
             # CPU logic
             # todo may need to set a different number on different platforms
-            if IS_SERVER and self.qtype == SYM_INT4 and x_2d.shape[0] >= TORCH_LINEAR_THRESHOLD:
+            if IS_SERVER and (not IS_SPR) and \
+                    self.qtype == SYM_INT4 and x_2d.shape[0] >= TORCH_LINEAR_THRESHOLD:
                 x0_fp32 = ggml_int4_convert_fp32(x0, self.weight_shape, self.weight_length)
                 result = F.linear(x, x0_fp32, self.bias)
             else:

View file

@@ -79,3 +79,9 @@ def check_avx512_vnni():
 def is_server():
     return check_avx512_vnni()
+
+
+# todo: use cpuid to check SPR
+# note: now only SPR supports both avxvnni and avx512vnni
+def is_spr():
+    return check_avx_vnni() and check_avx512_vnni()