LLM: Support embedding quantization (only q2_k for now) (#10170)

* basic logic added
* basic support
* support save & load, update mixed strategy
* fix style
* use int8 for lm_head
* add check for xpu

This commit is contained in:
parent eca69a6022
commit 3288acb8de

4 changed files with 128 additions and 31 deletions

@@ -191,10 +191,10 @@ def convert_gptq(module, awq=False, llm_awq=False):
 def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                                  current_key_name=None, convert_shape_only=False,
                                  cpu_embedding=False, prefix_name='',
-                                 imatrix_data=None):
+                                 imatrix_data=None, embedding_qtype=None):
     from bigdl.llm.transformers.low_bit_linear import LowBitLinear, FP4Params, \
         FP16Linear, BF16Linear
-    from bigdl.llm.transformers.embedding import LLMEmbedding
+    from bigdl.llm.transformers.embedding import LLMEmbedding, LowBitEmbedding
     has_been_replaced = False

     for name, module in model.named_children():

@@ -323,6 +323,32 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                 sparse=module.sparse,
                 _weight=module.weight.data,
             )
+        elif type(module) == nn.Embedding and embedding_qtype is not None:
+            q_embedding = LowBitEmbedding(
+                num_embeddings=module.num_embeddings,
+                embedding_dim=module.embedding_dim,
+                padding_idx=module.padding_idx,
+                max_norm=module.max_norm,
+                norm_type=module.norm_type,
+                scale_grad_by_freq=module.scale_grad_by_freq,
+                sparse=module.sparse,
+                _weight=module.weight.data,
+                qtype=embedding_qtype,
+            )
+            device = module.weight.data.device
+            # Copy the weights
+            paramsLowBit = FP4Params(data=module.weight.data,
+                                     requires_grad=False,
+                                     quantized=False,
+                                     _shape=None,
+                                     convert_shape_only=convert_shape_only,
+                                     qtype=embedding_qtype,
+                                     in_features=module.embedding_dim).to(device)
+            q_embedding._parameters['weight'] = paramsLowBit
+            model._modules[name] = q_embedding
+            # Force requires grad to False to avoid unexpected errors
+            model._modules[name].requires_grad_(False)
+            module.weight = None

         # Remove the last key for recursion
         if len(list(module.children())) > 0:

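For intuition, the following is a minimal, self-contained sketch of the idea this new branch implements: swap an nn.Embedding for a table whose rows are stored quantized and are dequantized only for the ids actually looked up. It uses plain per-row int8 quantization in vanilla PyTorch rather than the q2_k kernels, so the class name Int8RowEmbedding and the scale layout are illustrative only, not part of the patch.

import torch
import torch.nn as nn

class Int8RowEmbedding(nn.Module):
    """Illustrative analogue of LowBitEmbedding: keep each embedding row as
    int8 plus a per-row scale and dequantize only the rows that get looked up."""
    def __init__(self, weight: torch.Tensor):
        super().__init__()
        # per-row absmax scale, shape (num_embeddings, 1)
        scale = weight.abs().amax(dim=1, keepdim=True).clamp(min=1e-8) / 127.0
        self.register_buffer("qweight", torch.round(weight / scale).to(torch.int8))
        self.register_buffer("scale", scale)

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        rows = self.qweight[input_ids]                 # gather quantized rows
        return rows.float() * self.scale[input_ids]    # dequantize on the fly

emb = nn.Embedding(1000, 64)
model = nn.Sequential()
model.add_module("embed_tokens", emb)
# mirror the model._modules[name] swap performed above
model._modules["embed_tokens"] = Int8RowEmbedding(emb.weight.data)
print(model._modules["embed_tokens"](torch.tensor([[1, 5, 7]])).shape)  # torch.Size([1, 3, 64])

The real branch stores the rows in ggml q2_k blocks via FP4Params and dequantizes them with a native kernel on XPU, but the control flow mirrors this sketch.
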
@@ -334,7 +360,8 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
             convert_shape_only,
             cpu_embedding,
             prefix_name=prefix_name + '.' + name if prefix_name != '' else name,
-            imatrix_data=imatrix_data
+            imatrix_data=imatrix_data,
+            embedding_qtype=embedding_qtype
         )
         has_been_replaced = _flag or has_been_replaced
     return model, has_been_replaced

@@ -512,7 +539,7 @@ def ggml_convert_low_bit(model, qtype, optimize_model=True,
                          convert_shape_only=False, device="cpu",
                          modules_to_not_convert=None, cpu_embedding=False,
                          lightweight_bmm=False, torch_dtype="auto",
-                         imatrix_data=None):
+                         imatrix_data=None, embedding_qtype=None):
     logger.info(f"Converting the current model to "
                 f"{list(ggml_tensor_qtype.keys())[list(ggml_tensor_qtype.values()).index(qtype)]} "
                 f"format......")

@@ -535,6 +562,7 @@ def ggml_convert_low_bit(model, qtype, optimize_model=True,
         model, qtype, modules_to_not_convert,
         None, convert_shape_only, cpu_embedding,
         imatrix_data=imatrix_data,
+        embedding_qtype=embedding_qtype
     )
     if not has_been_replaced:
         warnings.warn(

@@ -20,6 +20,8 @@ from torch import Tensor
 from torch.nn import functional as F
 from torch.nn import Parameter
 from typing import Optional
+from bigdl.llm.transformers.low_bit_linear import FP4Params
+from bigdl.llm.utils.common import invalidInputError


 # To prevent insufficient available memory when moving embedding from XPU back to CPU,

@@ -72,3 +74,39 @@ class LLMEmbedding(torch.nn.Embedding):

     def forward(self, x: Tensor):
         return super().forward(x.to('cpu')).to(x.device)
+
+
+class LowBitEmbedding(torch.nn.Embedding):
+    def __init__(self,
+                 num_embeddings: int,
+                 embedding_dim: int,
+                 padding_idx: Optional[int] = None,
+                 max_norm: Optional[float] = None,
+                 norm_type: float = 2.,
+                 scale_grad_by_freq: bool = False,
+                 sparse: bool = False,
+                 _weight: Optional[Tensor] = None,
+                 _freeze: bool = False,
+                 device=None, dtype=None,
+                 qtype=None) -> None:
+        super().__init__(num_embeddings, embedding_dim, padding_idx,
+                         max_norm, norm_type, scale_grad_by_freq, sparse,
+                         _weight, device, dtype)
+        self.weight = FP4Params(self.weight.data,
+                                requires_grad=False,
+                                quantized=False, _shape=None, qtype=qtype)
+        self.embedding_dim = embedding_dim
+
+    def forward(self, x: Tensor):
+        invalidInputError(x.device.type == "xpu",
+                          "`LowBitEmbedding` only supports GPU now.")
+        try:
+            import intel_extension_for_pytorch
+            import linear_q4_0
+        except ModuleNotFoundError:
+            invalidInputError(False,
+                              "Please `pip install bigdl_core_xe` first.")
+
+        result = linear_q4_0.dequantize_rows(x.contiguous(), self.weight.data,
+                                             self.weight.qtype, self.embedding_dim)
+        return result

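Note the guard in forward above: lookups are only supported on Intel GPUs (x.device.type == "xpu") and go through the linear_q4_0.dequantize_rows kernel shipped with bigdl_core_xe. A hedged, standalone sketch of constructing and querying the layer; the sizes, the assumption that moving to XPU triggers quantization of the FP4Params weight, and the exact output shape are inferred from how FP4Params behaves for LowBitLinear, not shown in this hunk.

import torch
from bigdl.llm.ggml.quantize import ggml_tensor_qtype
from bigdl.llm.transformers.embedding import LowBitEmbedding

# q2_k works on 256-element blocks, so keep embedding_dim a multiple of 256
emb = LowBitEmbedding(num_embeddings=4096, embedding_dim=1024,
                      _weight=torch.randn(4096, 1024),
                      qtype=ggml_tensor_qtype["q2_k"])
emb = emb.to("xpu")                      # assumed: FP4Params quantizes when moved to XPU

input_ids = torch.tensor([[3, 14, 159]], device="xpu")
hidden = emb(input_ids)                  # dequantize_rows gathers and dequantizes those rows
print(hidden.shape)                      # expected on the order of (1, 3, 1024)
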
@@ -129,6 +129,8 @@ class _BaseAutoModelClass:
                          added to llama.cpp.
         :param model_hub: str value, options are ``'huggingface'`` and ``'modelscope'``,
             specify the model hub. Default to be ``'huggingface'``.
+        :param embedding_qtype: str value, options are ``'q2_k'`` now. Default to be None.
+            Relevant low bit optimizations will be applied to nn.Embedding layer.
         :return: a model instance
         """
         pretrained_model_name_or_path = kwargs.get("pretrained_model_name_or_path", None) \

@@ -159,6 +161,7 @@ class _BaseAutoModelClass:
         user_quantization_config = kwargs.pop("quantization_config", None)
         speculative = kwargs.pop("speculative", False)
         torch_dtype = kwargs.pop("torch_dtype", None)
+        embedding_qtype = kwargs.pop("embedding_qtype", None)

         if user_quantization_config is not None and \
                 "BitsAndBytesConfig" in str(user_quantization_config.__class__):

@@ -278,9 +281,15 @@ class _BaseAutoModelClass:
             if q_k in ["iq2_xxs", "iq2_xs"]:
                 invalidInputError(imatrix_file is not None,
                                   "For iq2_xxs and iq2_xs quantization, imatrix is needed.")
+            cpu_embedding = kwargs.get("cpu_embedding", False)
+            # for 2bit, default use embedding_quantization
+            if q_k in ["iq2_xxs", "iq2_xs", "q2_k"] and not cpu_embedding and \
+                    embedding_qtype is None:
+                embedding_qtype = "q2_k"
             if imatrix_file is not None:
                 imatrix_data = load_imatrix_data(imatrix_file)
-                kwargs['imatrix_data'] = imatrix_data
+                kwargs["imatrix_data"] = imatrix_data
+            kwargs["embedding_qtype"] = embedding_qtype
             model = cls.load_convert(q_k, optimize_model, *args, **kwargs)

             if speculative:

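At the user level the new option rides on from_pretrained. A hedged usage sketch follows (the model id is a placeholder); per the hunk above, embedding_qtype already defaults to "q2_k" for the 2-bit qtypes (iq2_xxs, iq2_xs, q2_k) unless cpu_embedding is enabled, so passing it explicitly is optional.

from bigdl.llm.transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-chat-hf",   # placeholder model id
    load_in_low_bit="q2_k",            # 2-bit k-quant for the Linear layers
    embedding_qtype="q2_k",            # also quantize nn.Embedding (the default for 2-bit)
    trust_remote_code=True,
)
model = model.to("xpu")                # LowBitEmbedding currently requires XPU
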
@@ -339,6 +348,9 @@ class _BaseAutoModelClass:
         lightweight_bmm = kwargs.pop("lightweight_bmm", False)
         quant_config = kwargs.pop("quantization_config", None)
         imatrix_data = kwargs.pop("imatrix_data", None)
+        embedding_qtype = kwargs.pop("embedding_qtype", None)
+        if embedding_qtype is not None:
+            embedding_qtype = ggml_tensor_qtype[embedding_qtype]
         _args = copy.deepcopy(args)
         _kwargs = copy.deepcopy(kwargs)
         awq_config = None

@@ -400,7 +412,8 @@ class _BaseAutoModelClass:
                                      modules_to_not_convert=modules_to_not_convert,
                                      cpu_embedding=cpu_embedding, lightweight_bmm=lightweight_bmm,
                                      torch_dtype=kwargs.get("torch_dtype", 'auto'),
-                                     imatrix_data=imatrix_data)
+                                     imatrix_data=imatrix_data,
+                                     embedding_qtype=embedding_qtype)
         model.config.update({"bigdl_transformers_low_bit": q_k})

         # enable tie_word_embeddings for MPT

@@ -469,6 +482,7 @@ class _BaseAutoModelClass:
         offload_folder = kwargs.pop("offload_folder", None)
         offload_state_dict = kwargs.pop("offload_state_dict", False)
         torch_dtype = kwargs.pop("torch_dtype", "auto")
+        embedding_qtype = kwargs.pop("embedding_qtype", None)
         sharded_metadata = None

         config_dict, _ = PretrainedConfig.get_config_dict(pretrained_model_name_or_path)

@@ -488,6 +502,10 @@ class _BaseAutoModelClass:
         optimize_model = kwargs.pop("optimize_model", True)

         qtype = ggml_tensor_qtype[bigdl_transformers_low_bit]
+        if bigdl_transformers_low_bit in ["iq2_xxs", "iq2_xs", "q2_k"] and not cpu_embedding:
+            embedding_qtype = "q2_k"
+        if embedding_qtype is not None:
+            embedding_qtype = ggml_tensor_qtype[embedding_qtype]

         has_remote_code = hasattr(config, "auto_map") and cls.HF_Model.__name__ in config.auto_map
         has_local_code = type(config) in cls.HF_Model._model_mapping.keys()

@@ -572,7 +590,8 @@ class _BaseAutoModelClass:
             quant_device = "meta" if bigdl_lcmu_enabled else "cpu"
             model = ggml_convert_low_bit(model, qtype, optimize_model, device=quant_device,
                                          modules_to_not_convert=modules_to_not_convert,
-                                         cpu_embedding=cpu_embedding, lightweight_bmm=lightweight_bmm)
+                                         cpu_embedding=cpu_embedding, lightweight_bmm=lightweight_bmm,
+                                         embedding_qtype=embedding_qtype)

             if is_sharded:
                 loaded_state_dict_keys = sharded_metadata["all_checkpoint_keys"]

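The commit message also calls out save & load: the load_low_bit path above re-derives embedding_qtype from the bigdl_transformers_low_bit entry stored in the saved config, so a persisted 2-bit checkpoint gets its embedding re-quantized without extra arguments. A hedged round-trip sketch (model id and path are placeholders):

from bigdl.llm.transformers import AutoModelForCausalLM

save_dir = "./llama2-7b-q2k"           # placeholder path

# first run: quantize and persist the low-bit checkpoint
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf",
                                             load_in_low_bit="q2_k")
model.save_low_bit(save_dir)

# later: reload; embedding_qtype falls back to "q2_k" because the saved config
# records bigdl_transformers_low_bit == "q2_k" and cpu_embedding is off
model = AutoModelForCausalLM.load_low_bit(save_dir).to("xpu")
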
@@ -224,16 +224,10 @@ def load_imatrix_data(imatrix_file):
     return imatrix_data


-def get_cur_qtype_and_imatrix(qtype, full_module_name, imatrix_data):
-    if qtype in [ggml_tensor_qtype["iq2_xxs"], ggml_tensor_qtype["iq2_xs"],
-                 ggml_tensor_qtype["q2_k"]] and imatrix_data is not None:
-        # For quantization which needs importance matrix
-        # module name preprocess
-        # full name maybe model.layers.31.self_attn.o_proj
-        # TODO: just consider llama/mistral here
-        # TODO: how to better aligned and generalize
-        module_name = full_module_name.split('.')
-        cur_qtype = qtype
-        if len(module_name) == 5:
-            layer = module_name[2]
-            cur_module = module_name[-1][:-5]
+def module_name_process(full_module_name):
+    # full name maybe model.layers.31.self_attn.o_proj
+    # TODO: how to better aligned and generalize
+    module_name = full_module_name.split('.')
+    if len(module_name) == 5:
+        layer = module_name[2]
+        cur_module = module_name[-1][:-5]

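Only the splitting and suffix-stripping of module_name_process are visible in this hunk (how new_module_name is built for the five-component case sits between the hunks). The [:-5] slice simply drops a '_proj' suffix, which is what the 'v'/'down' comparisons in the next hunk rely on; a small illustration of the visible logic:

# assumed behaviour of the lines shown above
name = "model.layers.31.self_attn.o_proj"
parts = name.split('.')        # ['model', 'layers', '31', 'self_attn', 'o_proj']
layer = parts[2]               # '31'
cur_module = parts[-1][:-5]    # 'o'  ('v_proj' -> 'v', 'down_proj' -> 'down')

# for a bare name such as 'lm_head' the else branch in the next hunk applies,
# presumably giving ('lm_head', None, None)
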
@@ -242,19 +236,37 @@ def get_cur_qtype_and_imatrix(qtype, full_module_name, imatrix_data):
-            new_module_name = module_name[0]
-            layer = None
-            cur_module = None
+        new_module_name = module_name[0]
+        layer = None
+        cur_module = None
+    return new_module_name, layer, cur_module
+
+
+def get_cur_qtype_and_imatrix(qtype, full_module_name, imatrix_data):
+    cur_qtype = qtype
+    if qtype in [ggml_tensor_qtype["iq2_xxs"], ggml_tensor_qtype["iq2_xs"]]:
+        # For quantization which needs importance matrix
+        new_module_name, layer, cur_module = module_name_process(full_module_name)
+        # custom mixed quantization strategy
+        if cur_module == 'v' or (cur_module == 'down' and int(layer) in [0, 1, 10, 11]):
+            cur_qtype = ggml_tensor_qtype['q2_k']
         if imatrix_data is not None and new_module_name in imatrix_data:
             cur_imatrix = imatrix_data[new_module_name]
-            # custom mixed quantization strategy
-            if cur_module == 'v' or (cur_module == 'down' and int(layer) in [0, 1, 10, 11]) \
-                    or new_module_name == 'lm_head':
-                cur_qtype = ggml_tensor_qtype['sym_int4']
         else:
+            # if no imatrix is available, use fp8 for lm_head
             cur_imatrix = None
-            # custom mixed quantization strategy
-            if cur_module == 'v' or (cur_module == 'down' and int(layer) in [0, 1, 10, 11]) \
-                    or new_module_name == 'lm_head':
-                cur_qtype = ggml_tensor_qtype['sym_int4']
+            if new_module_name == 'lm_head':
+                cur_qtype = ggml_tensor_qtype['sym_int8']
         return cur_qtype, cur_imatrix
+    elif qtype == ggml_tensor_qtype["q2_k"]:
+        new_module_name, layer, cur_module = module_name_process(full_module_name)
+        if cur_module == 'v' or (cur_module == 'down' and int(layer) in [0, 1, 10, 11]):
+            # TODO: q2_k need others k-quants type here
+            cur_qtype = ggml_tensor_qtype['q2_k']
+        if imatrix_data is not None and new_module_name in imatrix_data:
+            cur_imatrix = imatrix_data[new_module_name]
+        else:
+            # if no imatrix is available, use fp8 for lm_head
+            cur_imatrix = None
+            if new_module_name == 'lm_head':
+                cur_qtype = ggml_tensor_qtype['sym_int8']
     else:
         return qtype, None

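Putting the mixed strategy together: value projections and the down_proj of layers 0, 1, 10 and 11 are bumped to q2_k under the iq2 qtypes, an imatrix entry is attached per module when one is available, and lm_head falls back to sym_int8 when no imatrix is supplied. A hedged sketch of querying the strategy; the import path is an assumption (use whichever module holds load_imatrix_data in your tree), and the sketch uses iq2_xs, whose branch ends in an explicit return.

from bigdl.llm.ggml.quantize import ggml_tensor_qtype
# assumed location, alongside load_imatrix_data
from bigdl.llm.transformers.utils import get_cur_qtype_and_imatrix

qtype = ggml_tensor_qtype["iq2_xs"]
for name in ["model.layers.0.mlp.down_proj",     # down_proj of layer 0 -> bumped to q2_k
             "model.layers.5.self_attn.q_proj",  # ordinary projection  -> keeps iq2_xs
             "lm_head"]:                         # no imatrix           -> sym_int8
    cur_qtype, cur_imatrix = get_cur_qtype_and_imatrix(qtype, name, imatrix_data=None)
    print(name, cur_qtype, cur_imatrix)
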