[LLM] support transformer int4 + amx int4 (#8838)
This commit is contained in:
parent
ddff7a6f05
commit
7429ea0606
2 changed files with 10 additions and 2 deletions
|
|
@ -51,10 +51,11 @@ from torch import Tensor, device, dtype, nn
|
||||||
T = TypeVar("T", bound="torch.nn.Module")
|
T = TypeVar("T", bound="torch.nn.Module")
|
||||||
|
|
||||||
import bigdl.llm.ggml.model.llama.llama_cpp as ggml
|
import bigdl.llm.ggml.model.llama.llama_cpp as ggml
|
||||||
from bigdl.llm.utils.isa_checker import is_server
|
from bigdl.llm.utils.isa_checker import is_server, is_spr
|
||||||
import ctypes
|
import ctypes
|
||||||
from bigdl.llm.ggml.quantize import ggml_tensor_qtype
|
from bigdl.llm.ggml.quantize import ggml_tensor_qtype
|
||||||
IS_SERVER = is_server()
|
IS_SERVER = is_server()
|
||||||
|
IS_SPR = is_spr()
|
||||||
TORCH_LINEAR_THRESHOLD = 96
|
TORCH_LINEAR_THRESHOLD = 96
|
||||||
SYM_INT4 = ggml_tensor_qtype["sym_int4"]
|
SYM_INT4 = ggml_tensor_qtype["sym_int4"]
|
||||||
|
|
||||||
|
|
@ -256,7 +257,8 @@ class LinearQuant(nn.Linear):
|
||||||
else:
|
else:
|
||||||
# CPU logic
|
# CPU logic
|
||||||
# todo may need to set a different number on different platforms
|
# todo may need to set a different number on different platforms
|
||||||
if IS_SERVER and self.qtype == SYM_INT4 and x_2d.shape[0] >= TORCH_LINEAR_THRESHOLD:
|
if IS_SERVER and (not IS_SPR) and \
|
||||||
|
self.qtype == SYM_INT4 and x_2d.shape[0] >= TORCH_LINEAR_THRESHOLD:
|
||||||
x0_fp32 = ggml_int4_convert_fp32(x0, self.weight_shape, self.weight_length)
|
x0_fp32 = ggml_int4_convert_fp32(x0, self.weight_shape, self.weight_length)
|
||||||
result = F.linear(x, x0_fp32, self.bias)
|
result = F.linear(x, x0_fp32, self.bias)
|
||||||
else:
|
else:
|
||||||
|
|
|
||||||
|
|
@ -79,3 +79,9 @@ def check_avx512_vnni():
|
||||||
|
|
||||||
def is_server():
|
def is_server():
|
||||||
return check_avx512_vnni()
|
return check_avx512_vnni()
|
||||||
|
|
||||||
|
|
||||||
|
# todo: use cpuid to check SPR
|
||||||
|
# note: now only SPR supports both avxvnni and avx512vnni
|
||||||
|
def is_spr():
|
||||||
|
return check_avx_vnni() and check_avx512_vnni()
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue