LLM: Rename low bit layer (#8875)
* rename lowbit

Co-authored-by: leonardozcm <leonardozcm@gmail.com>

parent 74a2c2ddf5
commit 95271f10e0

5 changed files with 24 additions and 23 deletions
@@ -14,7 +14,7 @@
 # limitations under the License.
 #

-from .transformers import ggml_convert_quant
+from .transformers import ggml_convert_low_bit
 from bigdl.llm.ggml.quantize import ggml_tensor_qtype
 from bigdl.llm.utils.common import invalidInputError

@@ -34,4 +34,4 @@ def optimize_model(model, low_bit='sym_int4', optimize_llm=True):
                       f"Unknown load_in_low_bit value: {low_bit}, expected:"
                       f" sym_int4, asym_int4, sym_int5, asym_int5 or sym_int8.")
     qtype = ggml_tensor_qtype[low_bit]
-    return ggml_convert_quant(model, qtype=qtype, optimize_model=optimize_llm)
+    return ggml_convert_low_bit(model, qtype=qtype, optimize_model=optimize_llm)
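For context, the optimize_model helper edited above maps a low_bit string to a ggml qtype and now delegates to the renamed ggml_convert_low_bit. A minimal usage sketch, assuming the helper is exported as bigdl.llm.optimize_model and using a placeholder model path:

import torch
from transformers import AutoModelForCausalLM   # any PyTorch model works here
from bigdl.llm import optimize_model            # assumed export of the module edited above

# Load a full-precision model first, then convert its Linear layers in place.
model = AutoModelForCausalLM.from_pretrained("path/to/model", torch_dtype=torch.float32)
model = optimize_model(model, low_bit="sym_int4")  # sym_int4 is the default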
@@ -14,7 +14,8 @@
 # limitations under the License.
 #

-from .convert import ggml_convert_quant
+from .convert import ggml_convert_low_bit
 from .model import AutoModelForCausalLM, AutoModel, AutoModelForSeq2SeqLM, \
     AutoModelForSpeechSeq2Seq, AutoModelForQuestionAnswering, \
     AutoModelForSequenceClassification, AutoModelForMaskedLM, \
@@ -43,9 +43,9 @@ import transformers
 import importlib


-def _replace_with_quant_linear(model, qtype, modules_to_not_convert=None,
-                               current_key_name=None):
-    from bigdl.llm.transformers.linear_quant import LinearQuant, FP4Params
+def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
+                                 current_key_name=None):
+    from bigdl.llm.transformers.low_bit_linear import LowBitLinear, FP4Params
     has_been_replaced = False

     for name, module in model.named_children():
@@ -56,7 +56,7 @@ def _replace_with_quant_linear(model, qtype, modules_to_not_convert=None,
            # Check if the current key is not in the `modules_to_not_convert`
            if not any(key in ".".join(current_key_name) for key in modules_to_not_convert):
                with init_empty_weights():
-                    new_linear = LinearQuant(
+                    new_linear = LowBitLinear(
                        module.in_features,
                        module.out_features,
                        qtype,
@@ -65,12 +65,12 @@ def _replace_with_quant_linear(model, qtype, modules_to_not_convert=None,

                    device_type = module.weight.data.device.type
                    # Copy the weights
-                    paramsQuant = FP4Params(data=module.weight.data,
-                                            requires_grad=False,
-                                            quantized=False,
-                                            _shape=None,
-                                            qtype=qtype).to(device_type)
-                    new_linear._parameters['weight'] = paramsQuant
+                    paramsLowBit = FP4Params(data=module.weight.data,
+                                             requires_grad=False,
+                                             quantized=False,
+                                             _shape=None,
+                                             qtype=qtype).to(device_type)
+                    new_linear._parameters['weight'] = paramsLowBit

                    if module.bias is not None:
                        new_linear._parameters['bias'] = nn.Parameter(module.bias.data)\
@@ -85,7 +85,7 @@ def _replace_with_quant_linear(model, qtype, modules_to_not_convert=None,

        # Remove the last key for recursion
        if len(list(module.children())) > 0:
-            _, _flag = _replace_with_quant_linear(
+            _, _flag = _replace_with_low_bit_linear(
                module,
                qtype,
                modules_to_not_convert,
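The hunks above all follow the same recursive module-surgery pattern: walk model.named_children(), swap each nn.Linear for the low-bit layer, and recurse into containers. A self-contained sketch of that pattern using a stand-in replacement class (not the BigDL implementation):

import torch.nn as nn

class StubLowBitLinear(nn.Module):
    """Placeholder standing in for LowBitLinear in this sketch."""
    def __init__(self, in_features, out_features, bias=True):
        super().__init__()
        self.proxy = nn.Linear(in_features, out_features, bias=bias)

    def forward(self, x):
        return self.proxy(x)

def replace_linears(model: nn.Module) -> nn.Module:
    # Mirrors the shape of _replace_with_low_bit_linear: replace direct
    # children that are nn.Linear, then recurse into everything else.
    for name, module in model.named_children():
        if isinstance(module, nn.Linear):
            setattr(model, name, StubLowBitLinear(module.in_features,
                                                  module.out_features,
                                                  module.bias is not None))
        elif len(list(module.children())) > 0:
            replace_linears(module)
    return model

model = replace_linears(nn.Sequential(nn.Linear(256, 256), nn.ReLU(), nn.Linear(256, 8)))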
@@ -95,9 +95,9 @@ def _replace_with_quant_linear(model, qtype, modules_to_not_convert=None,
     return model, has_been_replaced


-def ggml_convert_quant(model, qtype, optimize_model=True, device="cpu"):
+def ggml_convert_low_bit(model, qtype, optimize_model=True, device="cpu"):
     modules_to_not_convert = []  # ["lm_head"]
-    model, has_been_replaced = _replace_with_quant_linear(
+    model, has_been_replaced = _replace_with_low_bit_linear(
         model, qtype, modules_to_not_convert, None
     )
     if not has_been_replaced:
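A hedged usage sketch of the renamed converter. The import path follows the `from .convert import ggml_convert_low_bit` lines in this commit, and the tiny Sequential model is only a stand-in (real use targets a Transformers model):

import torch.nn as nn
from bigdl.llm.ggml.quantize import ggml_tensor_qtype
from bigdl.llm.transformers.convert import ggml_convert_low_bit

model = nn.Sequential(nn.Linear(256, 256), nn.ReLU(), nn.Linear(256, 32))
# optimize_model=False skips the LLM-specific optimizations and only swaps
# nn.Linear modules for LowBitLinear with sym_int4 weights.
model = ggml_convert_low_bit(model, ggml_tensor_qtype["sym_int4"],
                             optimize_model=False)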
@@ -60,7 +60,7 @@ TORCH_LINEAR_THRESHOLD = 96
 SYM_INT4 = ggml_tensor_qtype["sym_int4"]


-def ggml_convert_quant(tensor: torch.Tensor, qtype: int, device=None):
+def ggml_convert_qtype(tensor: torch.Tensor, qtype: int, device=None):
     QK = ggml.ggml_qk_size(qtype)
     block_size_in_bytes = ggml.ggml_type_size(qtype)

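The two locals in the renamed tensor-level function, QK (elements per quantization block) and block_size_in_bytes (packed bytes per block), determine how large the quantized buffer is. A small illustration of that arithmetic; the 32-element / 18-byte figures are assumptions for a 4-bit block format, not values read from ggml:

def quantized_buffer_size(n_elements: int, qk: int, block_size_in_bytes: int) -> int:
    # Each group of `qk` weights is packed into one fixed-size block.
    assert n_elements % qk == 0, "tensor length must be a multiple of QK"
    return (n_elements // qk) * block_size_in_bytes

# A 4096 x 4096 fp32 weight (64 MiB) would pack into roughly 9 MiB here.
print(quantized_buffer_size(4096 * 4096, qk=32, block_size_in_bytes=18))  # 9437184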
@@ -123,7 +123,7 @@ class FP4Params(torch.nn.Parameter):
     def quantize(self, device=None):
         if not self.quantized:
             w = self.data.contiguous().float()
-            w_quantized = ggml_convert_quant(w, self.qtype,
+            w_quantized = ggml_convert_qtype(w, self.qtype,
                                              device=device)
             self.data = w_quantized
             self.quantized = True
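The quantize() hunk shows that FP4Params packs its fp32 data lazily, guarded by self.quantized. A sketch of constructing one directly, mirroring the convert-path call shown earlier; the shape is a placeholder, and despite the FP4Params name the qtype can be any supported low-bit format:

import torch
from bigdl.llm.ggml.quantize import ggml_tensor_qtype
from bigdl.llm.transformers.low_bit_linear import FP4Params

w = torch.randn(4096, 4096)
params = FP4Params(data=w, requires_grad=False, quantized=False,
                   _shape=None, qtype=ggml_tensor_qtype["sym_int4"])
# Same call the convert path makes; the packing guarded by self.quantized
# is expected to run during this conversion.
params = params.to("cpu")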
@@ -212,7 +212,7 @@ def ggml_matmul_src1_x_src0_t(src0: torch.Tensor,
     return result_t


-class LinearQuant(nn.Linear):
+class LowBitLinear(nn.Linear):
     def __init__(self, input_features, output_features, qtype, bias=True):
         super().__init__(input_features, output_features, bias)
         self.weight = FP4Params(self.weight.data,
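Putting the renamed pieces together, this is roughly what the convert path does for a single layer; a minimal sketch assuming the constructor signature shown above, with placeholder layer sizes:

import torch.nn as nn
from bigdl.llm.ggml.quantize import ggml_tensor_qtype
from bigdl.llm.transformers.low_bit_linear import LowBitLinear, FP4Params

qtype = ggml_tensor_qtype["sym_int4"]
lin = nn.Linear(4096, 11008, bias=False)   # the fp32 layer being replaced
qlin = LowBitLinear(lin.in_features, lin.out_features, qtype, bias=False)
qlin._parameters['weight'] = FP4Params(data=lin.weight.data,
                                       requires_grad=False,
                                       quantized=False,
                                       _shape=None,
                                       qtype=qtype).to("cpu")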
@@ -98,7 +98,7 @@ class _BaseAutoModelClass:

     @classmethod
     def load_convert(cls, q_k, optimize_model, *args, **kwargs):
-        from .convert import ggml_convert_quant
+        from .convert import ggml_convert_low_bit
         invalidInputError(q_k in ggml_tensor_qtype,
                          f"Unknown load_in_low_bit value: {q_k}, expected:"
                          f" sym_int4, asym_int4, sym_int5, asym_int5 or sym_int8.")
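load_convert appears to back the user-facing from_pretrained path, where the same strings accepted by optimize_model select the format. A usage sketch with a placeholder model path, assuming the load_in_low_bit keyword these wrappers validate above:

from bigdl.llm.transformers import AutoModelForCausalLM

# The low-bit string is checked against ggml_tensor_qtype exactly as shown above.
model = AutoModelForCausalLM.from_pretrained("path/to/llama-model",
                                             load_in_low_bit="sym_int4",
                                             trust_remote_code=True)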
@@ -117,7 +117,7 @@ class _BaseAutoModelClass:
         model = cls.HF_Model.from_pretrained(*_args, **_kwargs)
         model.config.update({"bigdl_lcmu_enabled": False})
         model = model.to("cpu")
-        model = ggml_convert_quant(model, qtype, optimize_model)
+        model = ggml_convert_low_bit(model, qtype, optimize_model)
         model.config.update({"bigdl_transformers_low_bit": q_k})

         # add save_low_bit to pretrained model dynamically
@@ -139,7 +139,7 @@ class _BaseAutoModelClass:
         from transformers.generation.configuration_utils import GenerationConfig
         from transformers.models.auto.auto_factory import _get_model_class
         from accelerate.big_modeling import init_empty_weights
-        from .convert import ggml_convert_quant
+        from .convert import ggml_convert_low_bit
         import copy
         import os
@@ -252,7 +252,7 @@ class _BaseAutoModelClass:

         # Loading args may differ based on their usage
         quant_device = "meta" if bigdl_lcmu_enabled else "cpu"
-        model = ggml_convert_quant(model, qtype, optimize_model, device=quant_device)
+        model = ggml_convert_low_bit(model, qtype, optimize_model, device=quant_device)

         if is_sharded:
             loaded_state_dict_keys = sharded_metadata["all_checkpoint_keys"]
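The final hunks sit on the save/load round trip: the converted model records bigdl_transformers_low_bit in its config, gains save_low_bit dynamically, and the low-bit loader re-runs ggml_convert_low_bit (on the meta device when bigdl_lcmu_enabled). A round-trip sketch with placeholder paths, assuming the save_low_bit / load_low_bit pair referenced in the comments above:

from bigdl.llm.transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("path/to/llama-model",
                                             load_in_low_bit="sym_int4")
model.save_low_bit("./llama-sym-int4")   # persists the packed low-bit weights

# Later: reload directly from the low-bit checkpoint, without the original
# fp16/fp32 weights.
model = AutoModelForCausalLM.load_low_bit("./llama-sym-int4")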