Fix modules_not_to_convert argument (#10483)
This commit is contained in:
parent
cbe24cc7e6
commit
cfdf8ad496
1 changed files with 122 additions and 125 deletions
|
|
@ -189,7 +189,7 @@ def convert_gptq(module, awq=False, llm_awq=False):
|
||||||
|
|
||||||
|
|
||||||
def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
|
def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
|
||||||
current_key_name=None, convert_shape_only=False,
|
convert_shape_only=False,
|
||||||
cpu_embedding=False, prefix_name='',
|
cpu_embedding=False, prefix_name='',
|
||||||
imatrix_data=None, embedding_qtype=None,
|
imatrix_data=None, embedding_qtype=None,
|
||||||
model_type=None, torch_dtype=torch.float32,
|
model_type=None, torch_dtype=torch.float32,
|
||||||
|
|
@ -200,133 +200,131 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
|
||||||
has_been_replaced = False
|
has_been_replaced = False
|
||||||
|
|
||||||
for name, module in model.named_children():
|
for name, module in model.named_children():
|
||||||
if current_key_name is None:
|
|
||||||
current_key_name = []
|
|
||||||
|
|
||||||
is_linear, linear_args = is_linear_module(module)
|
is_linear, linear_args = is_linear_module(module)
|
||||||
full_module_name = prefix_name + '.' + name if prefix_name != '' else name
|
full_module_name = prefix_name + '.' + name if prefix_name != '' else name
|
||||||
if is_linear and name not in modules_to_not_convert and \
|
|
||||||
full_module_name not in modules_to_not_convert:
|
# use sub-string to match, it may match `10` if user only pass a number like `0`
|
||||||
# Check if the current key is not in the `modules_to_not_convert`
|
if any(key in full_module_name for key in modules_to_not_convert):
|
||||||
if (not any(key in ".".join(current_key_name) for key in modules_to_not_convert) and
|
continue
|
||||||
not isinstance(module, LowBitLinear)):
|
|
||||||
in_features, out_features, mp_group = linear_args
|
if is_linear and not isinstance(module, LowBitLinear):
|
||||||
optimize_lm_head = False
|
in_features, out_features, mp_group = linear_args
|
||||||
if name == "lm_head":
|
optimize_lm_head = False
|
||||||
if model_type in ["gptj", "llama"] and os.environ.get("BIGDL_OPTIMIZE_LM_HEAD",
|
if name == "lm_head":
|
||||||
None) == "1":
|
if model_type in ["gptj", "llama"] and os.environ.get("BIGDL_OPTIMIZE_LM_HEAD",
|
||||||
optimize_lm_head = True
|
None) == "1":
|
||||||
with init_empty_weights():
|
optimize_lm_head = True
|
||||||
new_linear = None
|
with init_empty_weights():
|
||||||
is_gptq = is_auto_gptq_available() and isinstance(module, QuantLinearCudaOld)
|
new_linear = None
|
||||||
is_awq = is_auto_awq_available() and isinstance(module, WQLinear_GEMM)
|
is_gptq = is_auto_gptq_available() and isinstance(module, QuantLinearCudaOld)
|
||||||
is_llm_awq = is_awq and module.backend == AwqBackendPackingMethod.LLMAWQ
|
is_awq = is_auto_awq_available() and isinstance(module, WQLinear_GEMM)
|
||||||
if is_gptq or is_awq:
|
is_llm_awq = is_awq and module.backend == AwqBackendPackingMethod.LLMAWQ
|
||||||
has_bias = module.bias is not None and module.bias.abs().sum() != 0
|
if is_gptq or is_awq:
|
||||||
new_linear = LowBitLinear(
|
has_bias = module.bias is not None and module.bias.abs().sum() != 0
|
||||||
in_features,
|
new_linear = LowBitLinear(
|
||||||
out_features,
|
in_features,
|
||||||
qtype=qtype,
|
out_features,
|
||||||
bias=has_bias,
|
qtype=qtype,
|
||||||
mp_group=mp_group,
|
bias=has_bias,
|
||||||
enable_xetla=enable_xetla,
|
mp_group=mp_group,
|
||||||
optimize_lm_head=optimize_lm_head
|
enable_xetla=enable_xetla,
|
||||||
)
|
optimize_lm_head=optimize_lm_head
|
||||||
device = module.qweight.data.device
|
)
|
||||||
invalidInputError(device.type != "meta",
|
device = module.qweight.data.device
|
||||||
"converting from meta device is not supported")
|
invalidInputError(device.type != "meta",
|
||||||
# Copy the weights
|
"converting from meta device is not supported")
|
||||||
paramsLowBit = FP4Params(data=convert_gptq(module, awq=is_awq,
|
# Copy the weights
|
||||||
llm_awq=is_llm_awq),
|
paramsLowBit = FP4Params(data=convert_gptq(module, awq=is_awq,
|
||||||
requires_grad=False,
|
llm_awq=is_llm_awq),
|
||||||
quantized=True,
|
requires_grad=False,
|
||||||
_shape=(out_features, in_features),
|
quantized=True,
|
||||||
convert_shape_only=convert_shape_only,
|
_shape=(out_features, in_features),
|
||||||
qtype=qtype,
|
convert_shape_only=convert_shape_only,
|
||||||
enable_xetla=enable_xetla).to(device)
|
qtype=qtype,
|
||||||
new_linear._parameters['weight'] = paramsLowBit
|
enable_xetla=enable_xetla).to(device)
|
||||||
if has_bias:
|
new_linear._parameters['weight'] = paramsLowBit
|
||||||
new_linear._parameters['bias'] = nn.Parameter(module.bias.data)\
|
if has_bias:
|
||||||
.to(device)
|
new_linear._parameters['bias'] = nn.Parameter(module.bias.data)\
|
||||||
elif qtype not in [ggml_tensor_qtype["fp16"], ggml_tensor_qtype["bf16"]]:
|
.to(device)
|
||||||
if in_features % 64 != 0:
|
elif qtype not in [ggml_tensor_qtype["fp16"], ggml_tensor_qtype["bf16"]]:
|
||||||
# now our kernel requires in_features is a multiple of 64
|
if in_features % 64 != 0:
|
||||||
continue
|
# now our kernel requires in_features is a multiple of 64
|
||||||
new_linear = LowBitLinear(
|
continue
|
||||||
in_features,
|
new_linear = LowBitLinear(
|
||||||
out_features,
|
in_features,
|
||||||
qtype,
|
out_features,
|
||||||
module.bias is not None,
|
qtype,
|
||||||
mp_group=mp_group,
|
module.bias is not None,
|
||||||
enable_xetla=enable_xetla,
|
mp_group=mp_group,
|
||||||
optimize_lm_head=optimize_lm_head
|
enable_xetla=enable_xetla,
|
||||||
)
|
optimize_lm_head=optimize_lm_head
|
||||||
cur_qtype, cur_imatrix = get_cur_qtype_and_imatrix(qtype,
|
)
|
||||||
full_module_name,
|
cur_qtype, cur_imatrix = get_cur_qtype_and_imatrix(qtype,
|
||||||
imatrix_data,
|
full_module_name,
|
||||||
model_type)
|
imatrix_data,
|
||||||
device = module.weight.data.device
|
model_type)
|
||||||
# Copy the weights
|
device = module.weight.data.device
|
||||||
paramsLowBit = FP4Params(data=module.weight.data,
|
# Copy the weights
|
||||||
requires_grad=False,
|
paramsLowBit = FP4Params(data=module.weight.data,
|
||||||
quantized=False,
|
requires_grad=False,
|
||||||
_shape=None,
|
quantized=False,
|
||||||
convert_shape_only=convert_shape_only,
|
_shape=None,
|
||||||
qtype=cur_qtype,
|
convert_shape_only=convert_shape_only,
|
||||||
imatrix=cur_imatrix,
|
qtype=cur_qtype,
|
||||||
in_features=in_features,
|
imatrix=cur_imatrix,
|
||||||
enable_xetla=enable_xetla).to(device)
|
in_features=in_features,
|
||||||
new_linear._parameters['weight'] = paramsLowBit
|
enable_xetla=enable_xetla).to(device)
|
||||||
if module.bias is not None:
|
new_linear._parameters['weight'] = paramsLowBit
|
||||||
new_linear._parameters['bias'] = nn.Parameter(module.bias.data)\
|
if module.bias is not None:
|
||||||
.to(device)
|
new_linear._parameters['bias'] = nn.Parameter(module.bias.data)\
|
||||||
elif qtype == ggml_tensor_qtype["fp16"]:
|
.to(device)
|
||||||
module.to(torch.float16)
|
elif qtype == ggml_tensor_qtype["fp16"]:
|
||||||
new_linear = FP16Linear(
|
module.to(torch.float16)
|
||||||
in_features,
|
new_linear = FP16Linear(
|
||||||
out_features,
|
in_features,
|
||||||
module.bias is not None,
|
out_features,
|
||||||
mp_group=mp_group,
|
module.bias is not None,
|
||||||
optimize_lm_head=optimize_lm_head
|
mp_group=mp_group,
|
||||||
)
|
optimize_lm_head=optimize_lm_head
|
||||||
device = module.weight.data.device
|
)
|
||||||
from bigdl.llm.transformers.utils import get_ipex_version
|
device = module.weight.data.device
|
||||||
if get_ipex_version() < "2.1.10+xpu":
|
from bigdl.llm.transformers.utils import get_ipex_version
|
||||||
new_linear._parameters['weight'] = nn.Parameter(module.weight)
|
if get_ipex_version() < "2.1.10+xpu":
|
||||||
else:
|
|
||||||
# only from 2.1, ipex provides matmul_bias_out
|
|
||||||
# so we need to transpose weight
|
|
||||||
new_weight = module.weight.transpose(0, 1).contiguous()
|
|
||||||
new_linear._parameters['weight'] = nn.Parameter(new_weight)
|
|
||||||
new_linear.weight_type = 2
|
|
||||||
if module.bias is not None:
|
|
||||||
new_linear._parameters['bias'] = nn.Parameter(module.bias.data)\
|
|
||||||
.to(device)
|
|
||||||
elif qtype == ggml_tensor_qtype["bf16"]:
|
|
||||||
module.to(torch.bfloat16)
|
|
||||||
new_linear = BF16Linear(
|
|
||||||
in_features,
|
|
||||||
out_features,
|
|
||||||
module.bias is not None,
|
|
||||||
mp_group=mp_group,
|
|
||||||
optimize_lm_head=optimize_lm_head
|
|
||||||
)
|
|
||||||
device = module.weight.data.device
|
|
||||||
# convert here
|
|
||||||
new_linear._parameters['weight'] = nn.Parameter(module.weight)
|
new_linear._parameters['weight'] = nn.Parameter(module.weight)
|
||||||
if module.bias is not None:
|
else:
|
||||||
new_linear._parameters['bias'] = nn.Parameter(module.bias.data)\
|
# only from 2.1, ipex provides matmul_bias_out
|
||||||
.to(device)
|
# so we need to transpose weight
|
||||||
|
new_weight = module.weight.transpose(0, 1).contiguous()
|
||||||
|
new_linear._parameters['weight'] = nn.Parameter(new_weight)
|
||||||
|
new_linear.weight_type = 2
|
||||||
|
if module.bias is not None:
|
||||||
|
new_linear._parameters['bias'] = nn.Parameter(module.bias.data)\
|
||||||
|
.to(device)
|
||||||
|
elif qtype == ggml_tensor_qtype["bf16"]:
|
||||||
|
module.to(torch.bfloat16)
|
||||||
|
new_linear = BF16Linear(
|
||||||
|
in_features,
|
||||||
|
out_features,
|
||||||
|
module.bias is not None,
|
||||||
|
mp_group=mp_group,
|
||||||
|
optimize_lm_head=optimize_lm_head
|
||||||
|
)
|
||||||
|
device = module.weight.data.device
|
||||||
|
# convert here
|
||||||
|
new_linear._parameters['weight'] = nn.Parameter(module.weight)
|
||||||
|
if module.bias is not None:
|
||||||
|
new_linear._parameters['bias'] = nn.Parameter(module.bias.data)\
|
||||||
|
.to(device)
|
||||||
|
|
||||||
if new_linear is not None:
|
if new_linear is not None:
|
||||||
if not module.training:
|
if not module.training:
|
||||||
new_linear.eval()
|
new_linear.eval()
|
||||||
model._modules[name] = new_linear
|
model._modules[name] = new_linear
|
||||||
has_been_replaced = True
|
has_been_replaced = True
|
||||||
# Force requires grad to False to avoid unexpected errors
|
# Force requires grad to False to avoid unexpected errors
|
||||||
model._modules[name].requires_grad_(False)
|
model._modules[name].requires_grad_(False)
|
||||||
|
|
||||||
module.weight = None
|
module.weight = None
|
||||||
elif cpu_embedding and type(module) == nn.Embedding:
|
elif cpu_embedding and type(module) == nn.Embedding:
|
||||||
# skip user-defined Embedding layer
|
# skip user-defined Embedding layer
|
||||||
model._modules[name] = LLMEmbedding(
|
model._modules[name] = LLMEmbedding(
|
||||||
|
|
@ -375,7 +373,6 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
|
||||||
module,
|
module,
|
||||||
qtype,
|
qtype,
|
||||||
modules_to_not_convert,
|
modules_to_not_convert,
|
||||||
current_key_name,
|
|
||||||
convert_shape_only,
|
convert_shape_only,
|
||||||
cpu_embedding,
|
cpu_embedding,
|
||||||
prefix_name=prefix_name + '.' + name if prefix_name != '' else name,
|
prefix_name=prefix_name + '.' + name if prefix_name != '' else name,
|
||||||
|
|
@ -664,7 +661,7 @@ def ggml_convert_low_bit(model, qtype, optimize_model=True,
|
||||||
model_type = None
|
model_type = None
|
||||||
model, has_been_replaced = _replace_with_low_bit_linear(
|
model, has_been_replaced = _replace_with_low_bit_linear(
|
||||||
model, qtype, modules_to_not_convert,
|
model, qtype, modules_to_not_convert,
|
||||||
None, convert_shape_only, cpu_embedding,
|
convert_shape_only, cpu_embedding,
|
||||||
imatrix_data=imatrix_data,
|
imatrix_data=imatrix_data,
|
||||||
embedding_qtype=embedding_qtype,
|
embedding_qtype=embedding_qtype,
|
||||||
model_type=model_type,
|
model_type=model_type,
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue