LLM: basic api support for esimd fp16 (#9067)
* basic api support for fp16
* fix style
* fix
* fix error and style
* fix style
* meet code review
* update based on comments
parent 65373d2a8b
commit f64257a093

4 changed files with 103 additions and 28 deletions
@@ -31,7 +31,8 @@ ggml_tensor_qtype = {"sym_int4": 2,   # q4_0 in ggml
                      "asym_int5": 7,   # q5_1 in ggml
                      "sym_int8": 8,    # q8_0 in ggml
                      "nf4": 10,
-                     "nf3": 11}
+                     "nf3": 11,
+                     "fp16": 12}

 _llama_quantize_type = {"q4_0": 2,
                         "q4_1": 3,
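With this entry, the esimd fp16 format is addressable through the same qtype table as the existing low-bit formats. A minimal lookup sketch (the import path matches the one added in the convert change further down):

    from bigdl.llm.ggml.quantize import ggml_tensor_qtype

    qtype = ggml_tensor_qtype["fp16"]
    print(qtype)  # 12, the id used to select the esimd fp16 path during conversion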
@@ -71,7 +72,7 @@ def quantize(input_path: str, output_path: str,
     :param dtype: Quantization method which differs in the resulting model disk size and
            inference speed. Defalut to `q4_0`. Difference model family may support
            different types, now the supported list is:
-           llama : "q4_0", "q4_1", "q4_2"
+           llama : "q4_0", "q4_1", "q5_0", "q5_1", "q8_0"
           bloom : "q4_0", "q4_1"
           gptneox : "q4_0", "q4_1", "q5_0", "q5_1", "q8_0"
           starcoder : "q4_0", "q4_1", "q5_0", "q5_1", "q8_0"
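The docstring above belongs to the ggml quantize helper. A hedged usage sketch under assumptions: the import path is assumed to be the same module as ggml_tensor_qtype, the file paths are placeholders, and only input_path, output_path and dtype are visible in this hunk, so any further required keywords (for example a model-family argument) are an assumption for illustration:

    from bigdl.llm.ggml.quantize import quantize  # assumed module, not shown in the hunk

    quantize(input_path="./llama-7b-f16.bin",     # placeholder paths
             output_path="./llama-7b-q4_0.bin",
             model_family="llama",                # assumed keyword, not visible in this hunk
             dtype="q4_0")                        # one of the llama types listed above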
@@ -41,12 +41,13 @@ from accelerate import init_empty_weights
 import warnings
 import transformers
 import importlib
+from bigdl.llm.ggml.quantize import ggml_tensor_qtype
 from .utils import logger


 def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                                  current_key_name=None, convert_shape_only=False):
-    from bigdl.llm.transformers.low_bit_linear import LowBitLinear, FP4Params
+    from bigdl.llm.transformers.low_bit_linear import LowBitLinear, FP4Params, FP16Linear
     has_been_replaced = False

     for name, module in model.named_children():
@@ -57,6 +58,8 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
             # Check if the current key is not in the `modules_to_not_convert`
             if not any(key in ".".join(current_key_name) for key in modules_to_not_convert):
                 with init_empty_weights():
+                    new_linear = None
+                    if qtype != ggml_tensor_qtype["fp16"]:
                         new_linear = LowBitLinear(
                             module.in_features,
                             module.out_features,
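Together with the size check in the next hunk, the qtype gate above decides which modules take the new path. A small hypothetical helper (not part of this commit) that mirrors the selection logic:

    from bigdl.llm.ggml.quantize import ggml_tensor_qtype

    def uses_esimd_fp16(qtype: int, in_features: int) -> bool:
        # hypothetical helper: the fp16 branch is taken only for the fp16 qtype,
        # and for now only for the two hidden sizes handled by the esimd kernel
        return qtype == ggml_tensor_qtype["fp16"] and in_features in (4096, 11008)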
@@ -73,7 +76,27 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                             convert_shape_only=convert_shape_only,
                             qtype=qtype).to(device_type)
                         new_linear._parameters['weight'] = paramsLowBit
+                    else:
+                        # only support two size now
+                        # may generalize to other sizes
+                        if module.in_features in [4096, 11008]:
+                            # esimd fp16 path
+                            new_linear = FP16Linear(
+                                module.in_features,
+                                module.out_features,
+                                qtype,
+                                module.bias is not None,
+                            )
+                            device_type = module.weight.data.device.type
+
+                            # convert here
+                            m, n = module.weight.data.shape
+                            trans_weight = module.weight.data.reshape(m//16, 16, n)
+                            trans_weight = trans_weight.transpose(1, 2).contiguous()
+                            new_linear._parameters['weight'] = nn.Parameter(trans_weight)
+
+                    # fp16 may generalize to other sizes later
+                    if new_linear is not None:
                         if module.bias is not None:
                             new_linear._parameters['bias'] = nn.Parameter(module.bias.data)\
                                 .to(device_type)
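The reshape/transpose in the fp16 branch packs a (out_features, in_features) weight into the (out_features // 16, in_features, 16) layout the esimd kernel expects, and the first-token path in FP16Linear.forward below undoes it. A standalone illustration in plain torch (sizes chosen from the supported list above):

    import torch

    out_features, in_features = 4096, 4096
    w = torch.randn(out_features, in_features, dtype=torch.float16)

    m, n = w.shape
    packed = w.reshape(m // 16, 16, n).transpose(1, 2).contiguous()
    print(packed.shape)                   # torch.Size([256, 4096, 16])

    # the first-token path in FP16Linear.forward (next hunk) reverses the packing
    restored = packed.transpose(1, 2).reshape(m, n)
    assert torch.equal(restored, w)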
@@ -378,3 +378,53 @@ class LowBitLinear(nn.Linear):
             result += self.bias

         return result.to(x.dtype)
+
+
+class FP16Linear(nn.Linear):
+    def __init__(self, input_features, output_features, qtype, bias=True,
+                 conver_to_half=True):
+        super().__init__(input_features, output_features, bias)
+        self.in_len = input_features
+        self.out_len = output_features
+        self.weight_shape = (self.out_len, self.in_len)
+        self.weight_length = self.out_len * self.in_len
+        self.qtype = qtype
+        self.conver_to_half = conver_to_half
+
+    def forward(self, x: torch.Tensor):
+        if self.bias is not None and self.bias.dtype != x.dtype:
+            self.bias.data = self.bias.data.to(x.dtype)
+
+        x_shape = x.shape
+        x_2d = x.view(-1, x_shape[-1])
+
+        x0 = self.weight.data
+        # only work for GPU
+        invalidInputError(x0.device.type == "xpu",
+                          "FP16 only works for GPU")
+        try:
+            import intel_extension_for_pytorch
+            import linear_fp16_esimd
+        except ModuleNotFoundError:
+            invalidInputError(False,
+                              "Please `pip install bigdl_core_xe` first.")
+
+        if x_2d.is_contiguous() is False:
+            x_2d = x_2d.contiguous()
+
+        if x_2d.shape[0] > 1:
+            # first token or batch size > 1, re-convert weight
+            original_weight = self.weight.data.transpose(1, 2)
+            original_weight = original_weight.reshape(self.out_len, self.in_len)
+            result = F.linear(x_2d, original_weight.contiguous())
+            del original_weight
+        else:
+            # rest token, use esimd optimization
+            result = linear_fp16_esimd.forward(x_2d, self.weight.data)
+
+        new_shape = x_shape[:-1] + (self.out_len,)
+        result = result.view(new_shape)
+        if self.bias is not None:
+            result += self.bias
+
+        return result.to(x.dtype)
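FP16Linear keeps the packed weight and dispatches per call: inputs with more than one row fall back to F.linear on the unpacked weight, while single-row decode steps go through the esimd kernel. A hedged usage sketch, not a definitive recipe: it needs an Intel GPU plus the bigdl_core_xe kernels (per the checks in forward), and the sizes and random weights below are placeholders:

    import torch
    from torch import nn
    import intel_extension_for_pytorch as ipex   # registers the "xpu" device
    from bigdl.llm.ggml.quantize import ggml_tensor_qtype
    from bigdl.llm.transformers.low_bit_linear import FP16Linear

    in_len, out_len = 4096, 4096                 # one of the supported sizes
    layer = FP16Linear(in_len, out_len, ggml_tensor_qtype["fp16"], bias=False)

    # the weight must already be packed to (out_len // 16, in_len, 16),
    # exactly as the convert step above does it
    w = torch.randn(out_len, in_len, dtype=torch.float16)
    packed = w.reshape(out_len // 16, 16, in_len).transpose(1, 2).contiguous()
    layer._parameters['weight'] = nn.Parameter(packed)
    layer = layer.to("xpu")

    x = torch.randn(1, 1, in_len, dtype=torch.float16, device="xpu")
    out = layer(x)        # a single row, so this takes the esimd kernel path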
@@ -60,7 +60,7 @@ class _BaseAutoModelClass:
         :param load_in_4bit: boolean value, True means load linear's weight to symmetric int 4.
                              Default to be False.
         :param load_in_low_bit: str value, options are sym_int4, asym_int4, sym_int5, asym_int5
-                                or sym_int8. sym_int4 means symmetric int 4, asym_int4 means
+                                , sym_int8 or fp16. sym_int4 means symmetric int 4, asym_int4 means
                                 asymmetric int 4, etc. Relevant low bit optimizations will
                                 be applied to the model.
         :param optimize_model: boolean value, Whether to further optimize the low_bit llm model.
@@ -104,8 +104,9 @@ class _BaseAutoModelClass:
         from .convert import ggml_convert_low_bit
         invalidInputError(q_k in ggml_tensor_qtype,
                           f"Unknown load_in_low_bit value: {q_k}, expected:"
-                          f" sym_int4, asym_int4, sym_int5, asym_int5 or sym_int8.")
+                          f" sym_int4, asym_int4, sym_int5, asym_int5, sym_int8 or fp16.")
         qtype = ggml_tensor_qtype[q_k]
+
         # In case it needs a second try,
         # `from_pretrained`` may pop items out in dict
         # and lead to args missing.
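End to end, the new qtype is exposed through load_in_low_bit. A hedged sketch of the user-facing path under assumptions: the model id is a placeholder (a Llama-style model with hidden sizes 4096/11008 matches the supported shapes above), and running it needs an Intel GPU:

    from bigdl.llm.transformers import AutoModelForCausalLM
    from transformers import AutoTokenizer

    model_path = "meta-llama/Llama-2-7b-chat-hf"   # placeholder model id
    model = AutoModelForCausalLM.from_pretrained(model_path,
                                                 load_in_low_bit="fp16",
                                                 trust_remote_code=True)
    model = model.to("xpu")        # the esimd fp16 kernels only run on Intel GPU

    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    inputs = tokenizer("What is AI?", return_tensors="pt").to("xpu")
    output = model.generate(inputs.input_ids, max_new_tokens=32)
    print(tokenizer.decode(output[0], skip_special_tokens=True))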