[LLM] support transformer int4 + amx int4 (#8838)
parent ddff7a6f05
commit 7429ea0606

2 changed files with 10 additions and 2 deletions
@@ -51,10 +51,11 @@ from torch import Tensor, device, dtype, nn
 T = TypeVar("T", bound="torch.nn.Module")
 
 import bigdl.llm.ggml.model.llama.llama_cpp as ggml
-from bigdl.llm.utils.isa_checker import is_server
+from bigdl.llm.utils.isa_checker import is_server, is_spr
 import ctypes
 from bigdl.llm.ggml.quantize import ggml_tensor_qtype
 IS_SERVER = is_server()
+IS_SPR = is_spr()
 TORCH_LINEAR_THRESHOLD = 96
 SYM_INT4 = ggml_tensor_qtype["sym_int4"]
 
@@ -256,7 +257,8 @@ class LinearQuant(nn.Linear):
         else:
             # CPU logic
             # todo may need to set a different number on different platforms
-            if IS_SERVER and self.qtype == SYM_INT4 and x_2d.shape[0] >= TORCH_LINEAR_THRESHOLD:
+            if IS_SERVER and (not IS_SPR) and \
+                    self.qtype == SYM_INT4 and x_2d.shape[0] >= TORCH_LINEAR_THRESHOLD:
                 x0_fp32 = ggml_int4_convert_fp32(x0, self.weight_shape, self.weight_length)
                 result = F.linear(x, x0_fp32, self.bias)
             else:
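The changed condition is the heart of the commit: on non-SPR server CPUs, int4 weights are still dequantized to fp32 for large batches so torch's F.linear can run the GEMM, while Sapphire Rapids (SPR) machines now skip that fallback and stay on the ggml path, where AMX int4 kernels handle the matmul. Below is a minimal sketch of this dispatch; dequant_int4_to_fp32 and ggml_int4_matmul are hypothetical stand-ins for ggml_int4_convert_fp32 and the ggml kernel, simplified to plain fp32 math so the sketch actually runs:

import torch
import torch.nn.functional as F

# Threshold from the diff: batch sizes at or above this favor torch's fp32 GEMM.
TORCH_LINEAR_THRESHOLD = 96

def dequant_int4_to_fp32(weight_q: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in for ggml_int4_convert_fp32; the demo weight is
    # already fp32, so "dequantization" is the identity here.
    return weight_q

def ggml_int4_matmul(x_2d: torch.Tensor, weight_q: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in for the ggml int4 kernel (AMX int4 on SPR).
    return x_2d @ weight_q.t()

def int4_linear(x, weight_q, bias, is_server: bool, is_spr: bool):
    x_2d = x.reshape(-1, x.shape[-1])
    # Non-SPR servers with a large batch: dequantize and use torch, mirroring
    # the "IS_SERVER and (not IS_SPR)" condition introduced by the diff.
    if is_server and not is_spr and x_2d.shape[0] >= TORCH_LINEAR_THRESHOLD:
        return F.linear(x, dequant_int4_to_fp32(weight_q), bias)
    # SPR, or small batches anywhere: stay on the ggml int4 path.
    out = ggml_int4_matmul(x_2d, weight_q)
    if bias is not None:
        out = out + bias
    return out.reshape(*x.shape[:-1], out.shape[-1])

# Smoke test: a 128-row batch takes the torch path off SPR, the ggml path on SPR.
x, w = torch.randn(128, 64), torch.randn(32, 64)
assert int4_linear(x, w, None, is_server=True, is_spr=False).shape == (128, 32)
assert int4_linear(x, w, None, is_server=True, is_spr=True).shape == (128, 32)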
@@ -79,3 +79,9 @@ def check_avx512_vnni():
 
 def is_server():
     return check_avx512_vnni()
+
+
+# todo: use cpuid to check SPR
+# note: now only SPR supports both avxvnni and avx512vnni
+def is_spr():
+    return check_avx_vnni() and check_avx512_vnni()
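is_spr() above infers Sapphire Rapids purely from the ISA flag combination; the todo notes that a cpuid-based check would be more precise. For illustration, here is a self-contained check of the same shape, assuming a Linux /proc/cpuinfo that reports the avx_vnni and avx512_vnni flags. This is only one possible implementation, not necessarily how bigdl.llm.utils.isa_checker implements check_avx_vnni / check_avx512_vnni:

import platform

def _cpu_flags() -> set:
    # Assumption: Linux lists ISA flags in /proc/cpuinfo; other OSes would
    # need cpuid or an equivalent, which this sketch does not cover.
    if platform.system() != "Linux":
        return set()
    flags = set()
    with open("/proc/cpuinfo") as f:
        for line in f:
            if line.startswith("flags"):
                flags.update(line.split(":", 1)[1].split())
    return flags

def check_avx_vnni() -> bool:
    # Recent Linux kernels report AVX-VNNI as "avx_vnni".
    return "avx_vnni" in _cpu_flags()

def check_avx512_vnni() -> bool:
    return "avx512_vnni" in _cpu_flags()

def is_spr() -> bool:
    # Same heuristic as the diff: today only SPR supports both flag sets.
    return check_avx_vnni() and check_avx512_vnni()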