Fix modules_not_to_convert argument (#10483)

Author: Yishuo Wang (committed by GitHub)
Date:   2024-03-20 17:47:03 +08:00
Commit: cfdf8ad496
Parent: cbe24cc7e6


@@ -189,7 +189,7 @@ def convert_gptq(module, awq=False, llm_awq=False):
 
 
 def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
-                                 current_key_name=None, convert_shape_only=False,
+                                 convert_shape_only=False,
                                  cpu_embedding=False, prefix_name='',
                                  imatrix_data=None, embedding_qtype=None,
                                  model_type=None, torch_dtype=torch.float32,
@@ -200,133 +200,131 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
     has_been_replaced = False
     for name, module in model.named_children():
-        if current_key_name is None:
-            current_key_name = []
         is_linear, linear_args = is_linear_module(module)
         full_module_name = prefix_name + '.' + name if prefix_name != '' else name
-        if is_linear and name not in modules_to_not_convert and \
-                full_module_name not in modules_to_not_convert:
-            # Check if the current key is not in the `modules_to_not_convert`
-            if (not any(key in ".".join(current_key_name) for key in modules_to_not_convert) and
-                    not isinstance(module, LowBitLinear)):
-                in_features, out_features, mp_group = linear_args
-                optimize_lm_head = False
-                if name == "lm_head":
-                    if model_type in ["gptj", "llama"] and os.environ.get("BIGDL_OPTIMIZE_LM_HEAD",
-                                                                          None) == "1":
-                        optimize_lm_head = True
-                with init_empty_weights():
-                    new_linear = None
-                    is_gptq = is_auto_gptq_available() and isinstance(module, QuantLinearCudaOld)
-                    is_awq = is_auto_awq_available() and isinstance(module, WQLinear_GEMM)
-                    is_llm_awq = is_awq and module.backend == AwqBackendPackingMethod.LLMAWQ
-                    if is_gptq or is_awq:
-                        has_bias = module.bias is not None and module.bias.abs().sum() != 0
-                        new_linear = LowBitLinear(
-                            in_features,
-                            out_features,
-                            qtype=qtype,
-                            bias=has_bias,
-                            mp_group=mp_group,
-                            enable_xetla=enable_xetla,
-                            optimize_lm_head=optimize_lm_head
-                        )
-                        device = module.qweight.data.device
-                        invalidInputError(device.type != "meta",
-                                          "converting from meta device is not supported")
-                        # Copy the weights
-                        paramsLowBit = FP4Params(data=convert_gptq(module, awq=is_awq,
-                                                                   llm_awq=is_llm_awq),
-                                                 requires_grad=False,
-                                                 quantized=True,
-                                                 _shape=(out_features, in_features),
-                                                 convert_shape_only=convert_shape_only,
-                                                 qtype=qtype,
-                                                 enable_xetla=enable_xetla).to(device)
-                        new_linear._parameters['weight'] = paramsLowBit
-                        if has_bias:
-                            new_linear._parameters['bias'] = nn.Parameter(module.bias.data)\
-                                .to(device)
-                    elif qtype not in [ggml_tensor_qtype["fp16"], ggml_tensor_qtype["bf16"]]:
-                        if in_features % 64 != 0:
-                            # now our kernel requires in_features is a multiple of 64
-                            continue
-                        new_linear = LowBitLinear(
-                            in_features,
-                            out_features,
-                            qtype,
-                            module.bias is not None,
-                            mp_group=mp_group,
-                            enable_xetla=enable_xetla,
-                            optimize_lm_head=optimize_lm_head
-                        )
-                        cur_qtype, cur_imatrix = get_cur_qtype_and_imatrix(qtype,
-                                                                           full_module_name,
-                                                                           imatrix_data,
-                                                                           model_type)
-                        device = module.weight.data.device
-                        # Copy the weights
-                        paramsLowBit = FP4Params(data=module.weight.data,
-                                                 requires_grad=False,
-                                                 quantized=False,
-                                                 _shape=None,
-                                                 convert_shape_only=convert_shape_only,
-                                                 qtype=cur_qtype,
-                                                 imatrix=cur_imatrix,
-                                                 in_features=in_features,
-                                                 enable_xetla=enable_xetla).to(device)
-                        new_linear._parameters['weight'] = paramsLowBit
-                        if module.bias is not None:
-                            new_linear._parameters['bias'] = nn.Parameter(module.bias.data)\
-                                .to(device)
-                    elif qtype == ggml_tensor_qtype["fp16"]:
-                        module.to(torch.float16)
-                        new_linear = FP16Linear(
-                            in_features,
-                            out_features,
-                            module.bias is not None,
-                            mp_group=mp_group,
-                            optimize_lm_head=optimize_lm_head
-                        )
-                        device = module.weight.data.device
-                        from bigdl.llm.transformers.utils import get_ipex_version
-                        if get_ipex_version() < "2.1.10+xpu":
-                            new_linear._parameters['weight'] = nn.Parameter(module.weight)
-                        else:
-                            # only from 2.1, ipex provides matmul_bias_out
-                            # so we need to transpose weight
-                            new_weight = module.weight.transpose(0, 1).contiguous()
-                            new_linear._parameters['weight'] = nn.Parameter(new_weight)
-                            new_linear.weight_type = 2
-                        if module.bias is not None:
-                            new_linear._parameters['bias'] = nn.Parameter(module.bias.data)\
-                                .to(device)
-                    elif qtype == ggml_tensor_qtype["bf16"]:
-                        module.to(torch.bfloat16)
-                        new_linear = BF16Linear(
-                            in_features,
-                            out_features,
-                            module.bias is not None,
-                            mp_group=mp_group,
-                            optimize_lm_head=optimize_lm_head
-                        )
-                        device = module.weight.data.device
-                        # convert here
-                        new_linear._parameters['weight'] = nn.Parameter(module.weight)
-                        if module.bias is not None:
-                            new_linear._parameters['bias'] = nn.Parameter(module.bias.data)\
-                                .to(device)
-
-                    if new_linear is not None:
-                        if not module.training:
-                            new_linear.eval()
-                        model._modules[name] = new_linear
-                        has_been_replaced = True
-                        # Force requires grad to False to avoid unexpected errors
-                        model._modules[name].requires_grad_(False)
-                        module.weight = None
+        # use sub-string to match, it may match `10` if user only pass a number like `0`
+        if any(key in full_module_name for key in modules_to_not_convert):
+            continue
+
+        if is_linear and not isinstance(module, LowBitLinear):
+            in_features, out_features, mp_group = linear_args
+            optimize_lm_head = False
+            if name == "lm_head":
+                if model_type in ["gptj", "llama"] and os.environ.get("BIGDL_OPTIMIZE_LM_HEAD",
+                                                                      None) == "1":
+                    optimize_lm_head = True
+            with init_empty_weights():
+                new_linear = None
+                is_gptq = is_auto_gptq_available() and isinstance(module, QuantLinearCudaOld)
+                is_awq = is_auto_awq_available() and isinstance(module, WQLinear_GEMM)
+                is_llm_awq = is_awq and module.backend == AwqBackendPackingMethod.LLMAWQ
+                if is_gptq or is_awq:
+                    has_bias = module.bias is not None and module.bias.abs().sum() != 0
+                    new_linear = LowBitLinear(
+                        in_features,
+                        out_features,
+                        qtype=qtype,
+                        bias=has_bias,
+                        mp_group=mp_group,
+                        enable_xetla=enable_xetla,
+                        optimize_lm_head=optimize_lm_head
+                    )
+                    device = module.qweight.data.device
+                    invalidInputError(device.type != "meta",
+                                      "converting from meta device is not supported")
+                    # Copy the weights
+                    paramsLowBit = FP4Params(data=convert_gptq(module, awq=is_awq,
+                                                               llm_awq=is_llm_awq),
+                                             requires_grad=False,
+                                             quantized=True,
+                                             _shape=(out_features, in_features),
+                                             convert_shape_only=convert_shape_only,
+                                             qtype=qtype,
+                                             enable_xetla=enable_xetla).to(device)
+                    new_linear._parameters['weight'] = paramsLowBit
+                    if has_bias:
+                        new_linear._parameters['bias'] = nn.Parameter(module.bias.data)\
+                            .to(device)
+                elif qtype not in [ggml_tensor_qtype["fp16"], ggml_tensor_qtype["bf16"]]:
+                    if in_features % 64 != 0:
+                        # now our kernel requires in_features is a multiple of 64
+                        continue
+                    new_linear = LowBitLinear(
+                        in_features,
+                        out_features,
+                        qtype,
+                        module.bias is not None,
+                        mp_group=mp_group,
+                        enable_xetla=enable_xetla,
+                        optimize_lm_head=optimize_lm_head
+                    )
+                    cur_qtype, cur_imatrix = get_cur_qtype_and_imatrix(qtype,
+                                                                       full_module_name,
+                                                                       imatrix_data,
+                                                                       model_type)
+                    device = module.weight.data.device
+                    # Copy the weights
+                    paramsLowBit = FP4Params(data=module.weight.data,
+                                             requires_grad=False,
+                                             quantized=False,
+                                             _shape=None,
+                                             convert_shape_only=convert_shape_only,
+                                             qtype=cur_qtype,
+                                             imatrix=cur_imatrix,
+                                             in_features=in_features,
+                                             enable_xetla=enable_xetla).to(device)
+                    new_linear._parameters['weight'] = paramsLowBit
+                    if module.bias is not None:
+                        new_linear._parameters['bias'] = nn.Parameter(module.bias.data)\
+                            .to(device)
+                elif qtype == ggml_tensor_qtype["fp16"]:
+                    module.to(torch.float16)
+                    new_linear = FP16Linear(
+                        in_features,
+                        out_features,
+                        module.bias is not None,
+                        mp_group=mp_group,
+                        optimize_lm_head=optimize_lm_head
+                    )
+                    device = module.weight.data.device
+                    from bigdl.llm.transformers.utils import get_ipex_version
+                    if get_ipex_version() < "2.1.10+xpu":
+                        new_linear._parameters['weight'] = nn.Parameter(module.weight)
+                    else:
+                        # only from 2.1, ipex provides matmul_bias_out
+                        # so we need to transpose weight
+                        new_weight = module.weight.transpose(0, 1).contiguous()
+                        new_linear._parameters['weight'] = nn.Parameter(new_weight)
+                        new_linear.weight_type = 2
+                    if module.bias is not None:
+                        new_linear._parameters['bias'] = nn.Parameter(module.bias.data)\
+                            .to(device)
+                elif qtype == ggml_tensor_qtype["bf16"]:
+                    module.to(torch.bfloat16)
+                    new_linear = BF16Linear(
+                        in_features,
+                        out_features,
+                        module.bias is not None,
+                        mp_group=mp_group,
+                        optimize_lm_head=optimize_lm_head
+                    )
+                    device = module.weight.data.device
+                    # convert here
+                    new_linear._parameters['weight'] = nn.Parameter(module.weight)
+                    if module.bias is not None:
+                        new_linear._parameters['bias'] = nn.Parameter(module.bias.data)\
+                            .to(device)
+
+                if new_linear is not None:
+                    if not module.training:
+                        new_linear.eval()
+                    model._modules[name] = new_linear
+                    has_been_replaced = True
+                    # Force requires grad to False to avoid unexpected errors
+                    model._modules[name].requires_grad_(False)
+                    module.weight = None
         elif cpu_embedding and type(module) == nn.Embedding:
             # skip user-defined Embedding layer
             model._modules[name] = LLMEmbedding(
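
The rewritten check above matches every entry of modules_to_not_convert as a substring of the fully qualified module name, which is exactly the caveat the new in-diff comment raises: a bare "0" also matches layer "10". A minimal standalone sketch of that matching rule, using made-up module names purely for illustration:

# Sketch of the substring-based skip rule introduced in this commit.
# The module names below are illustrative examples, not taken from a real model.
modules_to_not_convert = ["lm_head", "0"]

candidates = [
    "lm_head",
    "model.layers.0.self_attn.q_proj",
    "model.layers.10.mlp.down_proj",   # also skipped, because "0" is a substring of "10"
]

for full_module_name in candidates:
    skipped = any(key in full_module_name for key in modules_to_not_convert)
    print(full_module_name, "-> skipped" if skipped else "-> converted")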
@@ -375,7 +373,6 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                 module,
                 qtype,
                 modules_to_not_convert,
-                current_key_name,
                 convert_shape_only,
                 cpu_embedding,
                 prefix_name=prefix_name + '.' + name if prefix_name != '' else name,
@@ -664,7 +661,7 @@ def ggml_convert_low_bit(model, qtype, optimize_model=True,
         model_type = None
     model, has_been_replaced = _replace_with_low_bit_linear(
         model, qtype, modules_to_not_convert,
-        None, convert_shape_only, cpu_embedding,
+        convert_shape_only, cpu_embedding,
         imatrix_data=imatrix_data,
         embedding_qtype=embedding_qtype,
         model_type=model_type,
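
With current_key_name gone, the dotted name used for matching is built solely from prefix_name, which the recursive call threads through as shown in the hunks above. A simplified sketch of that name-building scheme (the walk helper and the toy model are illustrative, not library code):

import torch.nn as nn

def walk(model, prefix_name=''):
    # Mirrors how the diff builds full_module_name and forwards it as prefix_name.
    for name, module in model.named_children():
        full_module_name = prefix_name + '.' + name if prefix_name != '' else name
        print(full_module_name)
        walk(module, prefix_name=full_module_name)

walk(nn.Sequential(nn.Linear(4, 4), nn.Sequential(nn.Linear(4, 2))))
# Prints 0, 1, 1.0 -- the same dotted paths that named_modules() reports.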