diff --git a/python/llm/src/ipex_llm/transformers/convert.py b/python/llm/src/ipex_llm/transformers/convert.py
index f8e82455..e23de58e 100644
--- a/python/llm/src/ipex_llm/transformers/convert.py
+++ b/python/llm/src/ipex_llm/transformers/convert.py
@@ -56,6 +56,7 @@ import sys
 
 _IS_VLLM_AVAILABLE = None
 _USE_VLLM = False
 _USE_VLLM_AWQ = False
+_USE_VLLM_GPTQ = False
 _VLLM_VERSION = None
 
@@ -144,7 +145,7 @@ def is_linear_module(module):
     is_awq = is_auto_awq_available() and isinstance(module, WQLinear_GEMM)
     if is_vllm_available():
         # Only convert vllm modules
-        global _VLLM_VERSION, _USE_VLLM_AWQ
+        global _VLLM_VERSION, _USE_VLLM_AWQ, _USE_VLLM_GPTQ
         if _VLLM_VERSION is None:
             _VLLM_VERSION = get_package_version('vllm')
         from vllm.model_executor.layers.linear import (
@@ -186,6 +187,10 @@ def is_linear_module(module):
                 and hasattr(module.quant_method, "quant_config")
                 and module.quant_method.quant_config.get_name() == "awq"):
             _USE_VLLM_AWQ = True
+        if (not _USE_VLLM_GPTQ
+                and hasattr(module.quant_method, "quant_config")
+                and module.quant_method.quant_config.get_name() == "gptq"):
+            _USE_VLLM_GPTQ = True
         invalidInputError(module.skip_bias_add is not True, "Currently, ipex-vllm does not"
                           " support linear layers with skip_bias_add argument")
         if isinstance(module, RowParallelLinear) and tp_size >= 2:
@@ -291,16 +296,23 @@ def convert_vllm(module, qtype, in_features, out_features, mp_group, cur_qtype,
     return new_linear
 
 
-def convert_vllm_awq(module):
+def convert_vllm_awq_or_gptq(module, gptq=False, act_order=False):
     from ipex_llm.transformers.low_bit_linear import get_block_size
     Q4_1 = get_block_size("asym_int4")
 
     scales = module.scales
-    wf = (torch.tensor([0, 4, 1, 5, 2, 6, 3, 7],
-                       dtype=torch.int32) * 4).unsqueeze(0)
     # vLLM only supports load 4-bits model, so this has been checked
-    bits = 4
+    if gptq:
+        bits = module.quant_method.quant_config.weight_bits
+        wf = (torch.tensor([0, 1, 2, 3, 4, 5, 6, 7],
+                           dtype=torch.int32) * 4).unsqueeze(0)
+    else:
+        bits = 4
+        wf = (torch.tensor([0, 4, 1, 5, 2, 6, 3, 7],
+                           dtype=torch.int32) * 4).unsqueeze(0)
     group_size = module.quant_method.quant_config.group_size
+    if int(group_size) % Q4_1 != 0:
+        invalidInputError(False, (f"group_size:{group_size} must be divisible by "
+                                  f"{Q4_1}."))
 
     zeros = torch.bitwise_right_shift(
         torch.unsqueeze(module.qzeros, 2).expand(-1, -1, 32 // bits),
@@ -309,13 +321,28 @@ def convert_vllm_awq(module):
 
     g_id_map = None
 
+    if gptq:
+        zeros = zeros + 1
     zeros = zeros.reshape(scales.shape)
 
-    weight = torch.bitwise_right_shift(
-        torch.unsqueeze(module.qweight, 2).expand(-1, -1, 32 // bits),
-        wf.unsqueeze(0)).to(torch.int16 if bits == 8 else torch.int8)
-    weight = torch.bitwise_and(weight, (2 ** bits) - 1)
-    weight = weight.reshape(weight.shape[0], weight.shape[1] * weight.shape[2])
+    if not gptq:
+        weight = torch.bitwise_right_shift(
+            torch.unsqueeze(module.qweight, 2).expand(-1, -1, 32 // bits),
+            wf.unsqueeze(0)).to(torch.int16 if bits == 8 else torch.int8)
+        weight = torch.bitwise_and(weight, (2 ** bits) - 1)
+        weight = weight.reshape(weight.shape[0], weight.shape[1] * weight.shape[2])
+    else:
+        weight = torch.bitwise_right_shift(
+            torch.unsqueeze(module.qweight, 1).expand(-1, 32 // bits, -1),
+            wf.unsqueeze(-1)).to(torch.int8)
+        weight = torch.bitwise_and(weight, (2 ** bits) - 1)
+        weight = weight.reshape(weight.shape[0] * weight.shape[1], weight.shape[2])
+
+    if act_order:
+        invalidInputError(module.g_idx.shape[0] == weight.shape[0],
+                          "g_idx and weight shape mismatch")
+        _, g_id_map = torch.sort(module.g_idx)
+        weight = weight[g_id_map, :]
 
     # convert weight to ggml format
     weight = weight.reshape(weight.shape[0]//group_size, group_size, weight.shape[1])
@@ -455,7 +482,7 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
         FP16Linear, BF16Linear
     from ipex_llm.transformers.embedding import CPUEmbedding, DiskEmbedding, LowBitEmbedding
     has_been_replaced = False
-    global _USE_VLLM_AWQ
+    global _USE_VLLM_AWQ, _USE_VLLM_GPTQ
 
     for name, module in model.named_children():
         is_linear, linear_args = is_linear_module(module)
@@ -523,7 +550,7 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                 if has_bias:
                     new_linear._parameters['bias'] = nn.Parameter(module.bias.data)\
                         .to(device)
-            elif _USE_VLLM_AWQ:
-                # User load an AWQ quantized model from vLLM
+            elif _USE_VLLM_AWQ or _USE_VLLM_GPTQ:
+                # User loads an AWQ or GPTQ quantized model from vLLM
                 from ipex_llm.transformers.low_bit_linear import vLLMLowBitLinear
                 from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
@@ -571,7 +598,8 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                 device = module.qweight.data.device
                 invalidInputError(device.type != "meta",
                                   "converting from meta device is not supported")
-                weight, g_idx_map = convert_vllm_awq(module)
+                weight, g_idx_map = convert_vllm_awq_or_gptq(module, gptq=_USE_VLLM_GPTQ,
+                                                             act_order=act_order)
                 if act_order:
                     new_linear.g_idx_map = g_idx_map
                 # Copy the weights
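Why the two `wf` shift tables differ: AWQ packs the eight 4-bit values of each int32 in the interleaved order [0, 4, 1, 5, 2, 6, 3, 7] along the output-feature dimension of `qweight`, while GPTQ packs them in natural order [0, 1, 2, 3, 4, 5, 6, 7] along the input-feature dimension (hence the `unsqueeze(module.qweight, 1)` and the transposed `expand`/`reshape` in the GPTQ branch). The sketch below is plain PyTorch, independent of ipex-llm and vLLM; the `pack` helper is hypothetical reference code rather than either loader's real packing routine, and it only demonstrates the shift-and-mask unpacking the patch performs.

import torch

bits = 4
nibbles = torch.arange(8, dtype=torch.int32)  # the eight 4-bit values of one int32

def pack(order):
    # Hypothetical packer: place logical value i at nibble slot order[i].
    word = torch.zeros((), dtype=torch.int32)
    for i, slot in enumerate(order):
        word |= nibbles[i] << (slot * bits)
    return word

awq_order = [0, 4, 1, 5, 2, 6, 3, 7]   # interleaved (AWQ)
gptq_order = [0, 1, 2, 3, 4, 5, 6, 7]  # natural (GPTQ)

# Unpack the same way convert_vllm_awq_or_gptq does:
# right-shift by the wf table, then mask off the low 4 bits.
for order in (awq_order, gptq_order):
    wf = torch.tensor(order, dtype=torch.int32) * bits
    word = pack(order)
    unpacked = torch.bitwise_and(torch.bitwise_right_shift(word, wf), (2 ** bits) - 1)
    assert torch.equal(unpacked, nibbles)  # round-trip recovers the original values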
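Two GPTQ-specific details in the conversion are easy to miss. First, GPTQ checkpoints conventionally store zero points offset by one, which is why the GPTQ path adds `zeros = zeros + 1` after unpacking. Second, with `act_order` (GPTQ's `desc_act`), `g_idx` records the quantization group of each input channel, and channels of one group are scattered across rows; sorting `g_idx` yields the permutation (`g_id_map`) that regroups rows so every consecutive `group_size` block shares one scale/zero pair, which is what makes the ggml-style reshape at the end valid. A minimal sketch with made-up sizes (plain PyTorch, all names hypothetical):

import torch

group_size, n_groups, out_features = 4, 3, 2
in_features = group_size * n_groups

# Hypothetical per-row group assignment as GPTQ's g_idx records it:
# with act_order, the rows of a group are scattered through the matrix.
g_idx = torch.randperm(in_features) // group_size

weight = torch.arange(in_features * out_features).reshape(in_features, out_features)

# torch.sort returns (values, indices); the indices are exactly the row
# permutation the patch applies as `weight = weight[g_id_map, :]`.
_, g_id_map = torch.sort(g_idx)
regrouped = weight[g_id_map, :]

# After regrouping, each consecutive block of `group_size` rows belongs to a
# single quantization group, so the ggml-style reshape is well-defined.
blocks = regrouped.reshape(in_features // group_size, group_size, out_features)
print(blocks.shape)  # torch.Size([3, 4, 2])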