diff --git a/python/llm/src/bigdl/llm/transformers/convert.py b/python/llm/src/bigdl/llm/transformers/convert.py
index 8a6dc24f..14e8fbba 100644
--- a/python/llm/src/bigdl/llm/transformers/convert.py
+++ b/python/llm/src/bigdl/llm/transformers/convert.py
@@ -189,7 +189,7 @@ def convert_gptq(module, awq=False, llm_awq=False):
 
 
 def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
-                                 current_key_name=None, convert_shape_only=False,
+                                 convert_shape_only=False,
                                  cpu_embedding=False, prefix_name='',
                                  imatrix_data=None, embedding_qtype=None,
                                  model_type=None, torch_dtype=torch.float32,
@@ -200,133 +200,131 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
     has_been_replaced = False
 
     for name, module in model.named_children():
-        if current_key_name is None:
-            current_key_name = []
-
         is_linear, linear_args = is_linear_module(module)
         full_module_name = prefix_name + '.' + name if prefix_name != '' else name
-        if is_linear and name not in modules_to_not_convert and \
-                full_module_name not in modules_to_not_convert:
-            # Check if the current key is not in the `modules_to_not_convert`
-            if (not any(key in ".".join(current_key_name) for key in modules_to_not_convert) and
-                    not isinstance(module, LowBitLinear)):
-                in_features, out_features, mp_group = linear_args
-                optimize_lm_head = False
-                if name == "lm_head":
-                    if model_type in ["gptj", "llama"] and os.environ.get("BIGDL_OPTIMIZE_LM_HEAD",
-                                                                          None) == "1":
-                        optimize_lm_head = True
-                with init_empty_weights():
-                    new_linear = None
-                    is_gptq = is_auto_gptq_available() and isinstance(module, QuantLinearCudaOld)
-                    is_awq = is_auto_awq_available() and isinstance(module, WQLinear_GEMM)
-                    is_llm_awq = is_awq and module.backend == AwqBackendPackingMethod.LLMAWQ
-                    if is_gptq or is_awq:
-                        has_bias = module.bias is not None and module.bias.abs().sum() != 0
-                        new_linear = LowBitLinear(
-                            in_features,
-                            out_features,
-                            qtype=qtype,
-                            bias=has_bias,
-                            mp_group=mp_group,
-                            enable_xetla=enable_xetla,
-                            optimize_lm_head=optimize_lm_head
-                        )
-                        device = module.qweight.data.device
-                        invalidInputError(device.type != "meta",
-                                          "converting from meta device is not supported")
-                        # Copy the weights
-                        paramsLowBit = FP4Params(data=convert_gptq(module, awq=is_awq,
-                                                                   llm_awq=is_llm_awq),
-                                                 requires_grad=False,
-                                                 quantized=True,
-                                                 _shape=(out_features, in_features),
-                                                 convert_shape_only=convert_shape_only,
-                                                 qtype=qtype,
-                                                 enable_xetla=enable_xetla).to(device)
-                        new_linear._parameters['weight'] = paramsLowBit
-                        if has_bias:
-                            new_linear._parameters['bias'] = nn.Parameter(module.bias.data)\
-                                .to(device)
-                    elif qtype not in [ggml_tensor_qtype["fp16"], ggml_tensor_qtype["bf16"]]:
-                        if in_features % 64 != 0:
-                            # now our kernel requires in_features is a multiple of 64
-                            continue
-                        new_linear = LowBitLinear(
-                            in_features,
-                            out_features,
-                            qtype,
-                            module.bias is not None,
-                            mp_group=mp_group,
-                            enable_xetla=enable_xetla,
-                            optimize_lm_head=optimize_lm_head
-                        )
-                        cur_qtype, cur_imatrix = get_cur_qtype_and_imatrix(qtype,
-                                                                           full_module_name,
-                                                                           imatrix_data,
-                                                                           model_type)
-                        device = module.weight.data.device
-                        # Copy the weights
-                        paramsLowBit = FP4Params(data=module.weight.data,
-                                                 requires_grad=False,
-                                                 quantized=False,
-                                                 _shape=None,
-                                                 convert_shape_only=convert_shape_only,
-                                                 qtype=cur_qtype,
-                                                 imatrix=cur_imatrix,
-                                                 in_features=in_features,
-                                                 enable_xetla=enable_xetla).to(device)
-                        new_linear._parameters['weight'] = paramsLowBit
-                        if module.bias is not None:
-                            new_linear._parameters['bias'] = nn.Parameter(module.bias.data)\
-                                .to(device)
-                    elif qtype == ggml_tensor_qtype["fp16"]:
-                        module.to(torch.float16)
-                        new_linear = FP16Linear(
-                            in_features,
-                            out_features,
-                            module.bias is not None,
-                            mp_group=mp_group,
-                            optimize_lm_head=optimize_lm_head
-                        )
-                        device = module.weight.data.device
-                        from bigdl.llm.transformers.utils import get_ipex_version
-                        if get_ipex_version() < "2.1.10+xpu":
-                            new_linear._parameters['weight'] = nn.Parameter(module.weight)
-                        else:
-                            # only from 2.1, ipex provides matmul_bias_out
-                            # so we need to transpose weight
-                            new_weight = module.weight.transpose(0, 1).contiguous()
-                            new_linear._parameters['weight'] = nn.Parameter(new_weight)
-                            new_linear.weight_type = 2
-                        if module.bias is not None:
-                            new_linear._parameters['bias'] = nn.Parameter(module.bias.data)\
-                                .to(device)
-                    elif qtype == ggml_tensor_qtype["bf16"]:
-                        module.to(torch.bfloat16)
-                        new_linear = BF16Linear(
-                            in_features,
-                            out_features,
-                            module.bias is not None,
-                            mp_group=mp_group,
-                            optimize_lm_head=optimize_lm_head
-                        )
-                        device = module.weight.data.device
-                        # convert here
+
+        # use sub-string to match, it may match `10` if user only pass a number like `0`
+        if any(key in full_module_name for key in modules_to_not_convert):
+            continue
+
+        if is_linear and not isinstance(module, LowBitLinear):
+            in_features, out_features, mp_group = linear_args
+            optimize_lm_head = False
+            if name == "lm_head":
+                if model_type in ["gptj", "llama"] and os.environ.get("BIGDL_OPTIMIZE_LM_HEAD",
+                                                                      None) == "1":
+                    optimize_lm_head = True
+            with init_empty_weights():
+                new_linear = None
+                is_gptq = is_auto_gptq_available() and isinstance(module, QuantLinearCudaOld)
+                is_awq = is_auto_awq_available() and isinstance(module, WQLinear_GEMM)
+                is_llm_awq = is_awq and module.backend == AwqBackendPackingMethod.LLMAWQ
+                if is_gptq or is_awq:
+                    has_bias = module.bias is not None and module.bias.abs().sum() != 0
+                    new_linear = LowBitLinear(
+                        in_features,
+                        out_features,
+                        qtype=qtype,
+                        bias=has_bias,
+                        mp_group=mp_group,
+                        enable_xetla=enable_xetla,
+                        optimize_lm_head=optimize_lm_head
+                    )
+                    device = module.qweight.data.device
+                    invalidInputError(device.type != "meta",
+                                      "converting from meta device is not supported")
+                    # Copy the weights
+                    paramsLowBit = FP4Params(data=convert_gptq(module, awq=is_awq,
+                                                               llm_awq=is_llm_awq),
+                                             requires_grad=False,
+                                             quantized=True,
+                                             _shape=(out_features, in_features),
+                                             convert_shape_only=convert_shape_only,
+                                             qtype=qtype,
+                                             enable_xetla=enable_xetla).to(device)
+                    new_linear._parameters['weight'] = paramsLowBit
+                    if has_bias:
+                        new_linear._parameters['bias'] = nn.Parameter(module.bias.data)\
+                            .to(device)
+                elif qtype not in [ggml_tensor_qtype["fp16"], ggml_tensor_qtype["bf16"]]:
+                    if in_features % 64 != 0:
+                        # now our kernel requires in_features is a multiple of 64
+                        continue
+                    new_linear = LowBitLinear(
+                        in_features,
+                        out_features,
+                        qtype,
+                        module.bias is not None,
+                        mp_group=mp_group,
+                        enable_xetla=enable_xetla,
+                        optimize_lm_head=optimize_lm_head
+                    )
+                    cur_qtype, cur_imatrix = get_cur_qtype_and_imatrix(qtype,
+                                                                       full_module_name,
+                                                                       imatrix_data,
+                                                                       model_type)
+                    device = module.weight.data.device
+                    # Copy the weights
+                    paramsLowBit = FP4Params(data=module.weight.data,
+                                             requires_grad=False,
+                                             quantized=False,
+                                             _shape=None,
+                                             convert_shape_only=convert_shape_only,
+                                             qtype=cur_qtype,
+                                             imatrix=cur_imatrix,
+                                             in_features=in_features,
+                                             enable_xetla=enable_xetla).to(device)
+                    new_linear._parameters['weight'] = paramsLowBit
+                    if module.bias is not None:
+                        new_linear._parameters['bias'] = nn.Parameter(module.bias.data)\
+                            .to(device)
+                elif qtype == ggml_tensor_qtype["fp16"]:
+                    module.to(torch.float16)
+                    new_linear = FP16Linear(
+                        in_features,
+                        out_features,
+                        module.bias is not None,
+                        mp_group=mp_group,
+                        optimize_lm_head=optimize_lm_head
+                    )
+                    device = module.weight.data.device
+                    from bigdl.llm.transformers.utils import get_ipex_version
+                    if get_ipex_version() < "2.1.10+xpu":
                         new_linear._parameters['weight'] = nn.Parameter(module.weight)
-                        if module.bias is not None:
-                            new_linear._parameters['bias'] = nn.Parameter(module.bias.data)\
-                                .to(device)
+                    else:
+                        # only from 2.1, ipex provides matmul_bias_out
+                        # so we need to transpose weight
+                        new_weight = module.weight.transpose(0, 1).contiguous()
+                        new_linear._parameters['weight'] = nn.Parameter(new_weight)
+                        new_linear.weight_type = 2
+                    if module.bias is not None:
+                        new_linear._parameters['bias'] = nn.Parameter(module.bias.data)\
+                            .to(device)
+                elif qtype == ggml_tensor_qtype["bf16"]:
+                    module.to(torch.bfloat16)
+                    new_linear = BF16Linear(
+                        in_features,
+                        out_features,
+                        module.bias is not None,
+                        mp_group=mp_group,
+                        optimize_lm_head=optimize_lm_head
+                    )
+                    device = module.weight.data.device
+                    # convert here
+                    new_linear._parameters['weight'] = nn.Parameter(module.weight)
+                    if module.bias is not None:
+                        new_linear._parameters['bias'] = nn.Parameter(module.bias.data)\
+                            .to(device)
 
-                if new_linear is not None:
-                    if not module.training:
-                        new_linear.eval()
-                    model._modules[name] = new_linear
-                    has_been_replaced = True
-                    # Force requires grad to False to avoid unexpected errors
-                    model._modules[name].requires_grad_(False)
+            if new_linear is not None:
+                if not module.training:
+                    new_linear.eval()
+                model._modules[name] = new_linear
+                has_been_replaced = True
+                # Force requires grad to False to avoid unexpected errors
+                model._modules[name].requires_grad_(False)
 
-                    module.weight = None
+                module.weight = None
         elif cpu_embedding and type(module) == nn.Embedding:
             # skip user-defined Embedding layer
             model._modules[name] = LLMEmbedding(
@@ -375,7 +373,6 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                 module,
                 qtype,
                 modules_to_not_convert,
-                current_key_name,
                 convert_shape_only,
                 cpu_embedding,
                 prefix_name=prefix_name + '.' + name if prefix_name != '' else name,
@@ -664,7 +661,7 @@ def ggml_convert_low_bit(model, qtype, optimize_model=True,
         model_type = None
     model, has_been_replaced = _replace_with_low_bit_linear(
         model, qtype, modules_to_not_convert,
-        None, convert_shape_only, cpu_embedding,
+        convert_shape_only, cpu_embedding,
        imatrix_data=imatrix_data,
        embedding_qtype=embedding_qtype,
        model_type=model_type,
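A minimal standalone sketch (not part of the patch; the layer names are hypothetical) of the sub-string matching that the new `modules_to_not_convert` check performs, illustrating the pitfall the in-code comment warns about: an entry like "0" also matches a layer whose name contains "10".

    modules_to_not_convert = ["0"]
    for full_module_name in ["model.layers.0.mlp", "model.layers.10.mlp", "lm_head"]:
        # same sub-string test as the patched code path
        if any(key in full_module_name for key in modules_to_not_convert):
            print("skipped:  ", full_module_name)   # matches both layer 0 and layer 10
        else:
            print("converted:", full_module_name)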