LLM: basic api support for esimd fp16 (#9067)
* basic api support for fp16
* fix style
* fix
* fix error and style
* fix style
* meet code review
* update based on comments
parent 65373d2a8b
commit f64257a093
4 changed files with 103 additions and 28 deletions
@@ -31,7 +31,8 @@ ggml_tensor_qtype = {"sym_int4": 2, # q4_0 in ggml
                      "asym_int5": 7,  # q5_1 in ggml
                      "sym_int8": 8,   # q8_0 in ggml
                      "nf4": 10,
-                     "nf3": 11}
+                     "nf3": 11,
+                     "fp16": 12}
 
 _llama_quantize_type = {"q4_0": 2,
                         "q4_1": 3,
@@ -71,7 +72,7 @@ def quantize(input_path: str, output_path: str,
     :param dtype: Quantization method which differs in the resulting model disk size and
            inference speed. Defalut to `q4_0`. Difference model family may support
            different types, now the supported list is:
-           llama : "q4_0", "q4_1", "q4_2"
+           llama : "q4_0", "q4_1", "q5_0", "q5_1", "q8_0"
            bloom : "q4_0", "q4_1"
            gptneox : "q4_0", "q4_1", "q5_0", "q5_1", "q8_0"
            starcoder : "q4_0", "q4_1", "q5_0", "q5_1", "q8_0"
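Note: a hedged usage sketch of the quantize API whose docstring is updated above; the import path and the model_family keyword are assumptions based on this file and the per-family lists, and may differ from the actual signature:

from bigdl.llm.ggml.quantize import quantize  # import path assumed from this module

# hypothetical paths; dtype chosen from the updated llama list above
quantize(input_path="./llama-7b-ggml-f16.bin",
         output_path="./llama-7b-ggml-q5_0.bin",
         model_family="llama",   # assumed keyword name
         dtype="q5_0")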
@@ -41,12 +41,13 @@ from accelerate import init_empty_weights
 import warnings
 import transformers
 import importlib
 from bigdl.llm.ggml.quantize import ggml_tensor_qtype
 from .utils import logger
 
 
 def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                                  current_key_name=None, convert_shape_only=False):
-    from bigdl.llm.transformers.low_bit_linear import LowBitLinear, FP4Params
+    from bigdl.llm.transformers.low_bit_linear import LowBitLinear, FP4Params, FP16Linear
     has_been_replaced = False
 
     for name, module in model.named_children():
@@ -57,6 +58,8 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
             # Check if the current key is not in the `modules_to_not_convert`
             if not any(key in ".".join(current_key_name) for key in modules_to_not_convert):
                 with init_empty_weights():
+                    new_linear = None
+                    if qtype != ggml_tensor_qtype["fp16"]:
                         new_linear = LowBitLinear(
                             module.in_features,
                             module.out_features,
@@ -73,7 +76,27 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                             convert_shape_only=convert_shape_only,
                             qtype=qtype).to(device_type)
                         new_linear._parameters['weight'] = paramsLowBit
+                    else:
+                        # only support two size now
+                        # may generalize to other sizes
+                        if module.in_features in [4096, 11008]:
+                            # esimd fp16 path
+                            new_linear = FP16Linear(
+                                module.in_features,
+                                module.out_features,
+                                qtype,
+                                module.bias is not None,
+                            )
+                            device_type = module.weight.data.device.type
+
+                            # convert here
+                            m, n = module.weight.data.shape
+                            trans_weight = module.weight.data.reshape(m//16, 16, n)
+                            trans_weight = trans_weight.transpose(1, 2).contiguous()
+                            new_linear._parameters['weight'] = nn.Parameter(trans_weight)
+
+                    # fp16 may generalize to other sizes later
+                    if new_linear is not None:
                         if module.bias is not None:
                             new_linear._parameters['bias'] = nn.Parameter(module.bias.data)\
                                 .to(device_type)
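Note: a minimal sketch (dummy tensor, not a real model weight) of what the conversion above does to the weight layout before it is handed to FP16Linear:

import torch

# dummy weight with one of the two supported in_features sizes
out_features, in_features = 11008, 4096
w = torch.randn(out_features, in_features, dtype=torch.float16)

m, n = w.shape                      # (11008, 4096)
trans_weight = w.reshape(m // 16, 16, n).transpose(1, 2).contiguous()
print(trans_weight.shape)           # torch.Size([688, 4096, 16])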
@@ -378,3 +378,53 @@ class LowBitLinear(nn.Linear):
             result += self.bias
 
         return result.to(x.dtype)
+
+
+class FP16Linear(nn.Linear):
+    def __init__(self, input_features, output_features, qtype, bias=True,
+                 conver_to_half=True):
+        super().__init__(input_features, output_features, bias)
+        self.in_len = input_features
+        self.out_len = output_features
+        self.weight_shape = (self.out_len, self.in_len)
+        self.weight_length = self.out_len * self.in_len
+        self.qtype = qtype
+        self.conver_to_half = conver_to_half
+
+    def forward(self, x: torch.Tensor):
+        if self.bias is not None and self.bias.dtype != x.dtype:
+            self.bias.data = self.bias.data.to(x.dtype)
+
+        x_shape = x.shape
+        x_2d = x.view(-1, x_shape[-1])
+
+        x0 = self.weight.data
+        # only work for GPU
+        invalidInputError(x0.device.type == "xpu",
+                          "FP16 only works for GPU")
+        try:
+            import intel_extension_for_pytorch
+            import linear_fp16_esimd
+        except ModuleNotFoundError:
+            invalidInputError(False,
+                              "Please `pip install bigdl_core_xe` first.")
+
+        if x_2d.is_contiguous() is False:
+            x_2d = x_2d.contiguous()
+
+        if x_2d.shape[0] > 1:
+            # first token or batch size > 1, re-convert weight
+            original_weight = self.weight.data.transpose(1, 2)
+            original_weight = original_weight.reshape(self.out_len, self.in_len)
+            result = F.linear(x_2d, original_weight.contiguous())
+            del original_weight
+        else:
+            # rest token, use esimd optimization
+            result = linear_fp16_esimd.forward(x_2d, self.weight.data)
+
+        new_shape = x_shape[:-1] + (self.out_len,)
+        result = result.view(new_shape)
+        if self.bias is not None:
+            result += self.bias
+
+        return result.to(x.dtype)
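Note: the first-token / batched branch above undoes the packing performed in _replace_with_low_bit_linear; a small self-contained round-trip check with a dummy tensor (no xpu or linear_fp16_esimd needed):

import torch

out_len, in_len = 4096, 4096
w = torch.randn(out_len, in_len, dtype=torch.float16)

# packing done at convert time: (out, in) -> (out//16, in, 16)
packed = w.reshape(out_len // 16, 16, in_len).transpose(1, 2).contiguous()

# un-packing done in forward when x_2d.shape[0] > 1
restored = packed.transpose(1, 2).reshape(out_len, in_len)

assert torch.equal(w, restored)

For the remaining single-token decoding steps (x_2d.shape[0] == 1), the packed weight is passed to linear_fp16_esimd.forward as-is.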
@@ -60,7 +60,7 @@ class _BaseAutoModelClass:
         :param load_in_4bit: boolean value, True means load linear's weight to symmetric int 4.
                              Default to be False.
         :param load_in_low_bit: str value, options are sym_int4, asym_int4, sym_int5, asym_int5
-                                or sym_int8. sym_int4 means symmetric int 4, asym_int4 means
+                                , sym_int8 or fp16. sym_int4 means symmetric int 4, asym_int4 means
                                 asymmetric int 4, etc. Relevant low bit optimizations will
                                 be applied to the model.
         :param optimize_model: boolean value, Whether to further optimize the low_bit llm model.
@@ -104,8 +104,9 @@ class _BaseAutoModelClass:
             from .convert import ggml_convert_low_bit
             invalidInputError(q_k in ggml_tensor_qtype,
                               f"Unknown load_in_low_bit value: {q_k}, expected:"
-                              f" sym_int4, asym_int4, sym_int5, asym_int5 or sym_int8.")
+                              f" sym_int4, asym_int4, sym_int5, asym_int5, sym_int8 or fp16.")
             qtype = ggml_tensor_qtype[q_k]
 
             # In case it needs a second try,
             # `from_pretrained`` may pop items out in dict
             # and lead to args missing.
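Note: a hedged end-to-end sketch of the new option; the model id is a placeholder and the surrounding setup (an Intel GPU with intel_extension_for_pytorch installed) is assumed:

import intel_extension_for_pytorch as ipex  # registers the "xpu" device
from bigdl.llm.transformers import AutoModelForCausalLM

# "fp16" joins the existing sym_int4/asym_int4/sym_int5/asym_int5/sym_int8 options;
# eligible nn.Linear layers (in_features 4096 or 11008) are replaced with FP16Linear.
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf",  # placeholder model id
                                             load_in_low_bit="fp16")
model = model.to("xpu")  # FP16Linear.forward requires the weights to live on xpu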