Enable vllm load gptq model (#12083)

* enable vllm load gptq model

* update

* update

* update

* update style
Author: Wang, Jian4 (committed by GitHub)
Date: 2024-09-18 14:41:00 +08:00
parent c2774e1a43
commit 40e463c66b

@@ -56,6 +56,7 @@ import sys
 _IS_VLLM_AVAILABLE = None
 _USE_VLLM = False
 _USE_VLLM_AWQ = False
+_USE_VLLM_GPTQ = False
 _VLLM_VERSION = None
@@ -144,7 +145,7 @@ def is_linear_module(module):
     is_awq = is_auto_awq_available() and isinstance(module, WQLinear_GEMM)
     if is_vllm_available():
         # Only convert vllm modules
-        global _VLLM_VERSION, _USE_VLLM_AWQ
+        global _VLLM_VERSION, _USE_VLLM_AWQ, _USE_VLLM_GPTQ
         if _VLLM_VERSION is None:
             _VLLM_VERSION = get_package_version('vllm')
         from vllm.model_executor.layers.linear import (
@@ -186,6 +187,10 @@ def is_linear_module(module):
                     and hasattr(module.quant_method, "quant_config")
                     and module.quant_method.quant_config.get_name() == "awq"):
                 _USE_VLLM_AWQ = True
+            if (not _USE_VLLM_GPTQ
+                    and hasattr(module.quant_method, "quant_config")
+                    and module.quant_method.quant_config.get_name() == "gptq"):
+                _USE_VLLM_GPTQ = True
             invalidInputError(module.skip_bias_add is not True, "Currently, ipex-vllm does not"
                               " support linear layers with skip_bias_add argument")
             if isinstance(module, RowParallelLinear) and tp_size >= 2:
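The GPTQ path is detected the same way as the existing AWQ path: the vLLM linear layer's quantization method is asked for its config name. A minimal stand-alone illustration of that check is below; the SimpleNamespace stand-in and _FakeQuantConfig are hypothetical stand-ins used only for demonstration, not vLLM classes.

    # Sketch of the detection logic above; the stand-in objects are illustrative only.
    from types import SimpleNamespace


    class _FakeQuantConfig:
        def __init__(self, name):
            self._name = name

        def get_name(self):
            return self._name


    module = SimpleNamespace(
        quant_method=SimpleNamespace(quant_config=_FakeQuantConfig("gptq")))

    use_vllm_gptq = (hasattr(module.quant_method, "quant_config")
                     and module.quant_method.quant_config.get_name() == "gptq")
    print(use_vllm_gptq)  # True -> the module will take the GPTQ conversion path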
@@ -291,16 +296,23 @@ def convert_vllm(module, qtype, in_features, out_features, mp_group, cur_qtype,
     return new_linear
 
 
-def convert_vllm_awq(module):
+def convert_vllm_awq_or_gptq(module, gptq=False, act_order=False):
     from ipex_llm.transformers.low_bit_linear import get_block_size
     Q4_1 = get_block_size("asym_int4")
 
     scales = module.scales
-    wf = (torch.tensor([0, 4, 1, 5, 2, 6, 3, 7],
-                       dtype=torch.int32) * 4).unsqueeze(0)
-    # vLLM only supports load 4-bits model, so this has been checked
-    bits = 4
+    # vLLM only supports load 4-bits model, so this has been checked
+    if gptq:
+        bits = module.quant_method.quant_config.weight_bits
+        wf = (torch.tensor([0, 1, 2, 3, 4, 5, 6, 7],
+                           dtype=torch.int32) * 4).unsqueeze(0)
+    else:
+        bits = 4
+        wf = (torch.tensor([0, 4, 1, 5, 2, 6, 3, 7],
+                           dtype=torch.int32) * 4).unsqueeze(0)
     group_size = module.quant_method.quant_config.group_size
+    if int(group_size) % Q4_1 != 0:
+        invalidInputError(False, (f"group_size:{group_size} must be divisible by "f"{Q4_1}."))
 
     zeros = torch.bitwise_right_shift(
         torch.unsqueeze(module.qzeros, 2).expand(-1, -1, 32 // bits),
@@ -309,13 +321,28 @@ def convert_vllm_awq(module):
 
     g_id_map = None
 
+    if gptq:
+        zeros = zeros + 1
     zeros = zeros.reshape(scales.shape)
 
-    weight = torch.bitwise_right_shift(
-        torch.unsqueeze(module.qweight, 2).expand(-1, -1, 32 // bits),
-        wf.unsqueeze(0)).to(torch.int16 if bits == 8 else torch.int8)
-    weight = torch.bitwise_and(weight, (2 ** bits) - 1)
-    weight = weight.reshape(weight.shape[0], weight.shape[1] * weight.shape[2])
+    if not gptq:
+        weight = torch.bitwise_right_shift(
+            torch.unsqueeze(module.qweight, 2).expand(-1, -1, 32 // bits),
+            wf.unsqueeze(0)).to(torch.int16 if bits == 8 else torch.int8)
+        weight = torch.bitwise_and(weight, (2 ** bits) - 1)
+        weight = weight.reshape(weight.shape[0], weight.shape[1] * weight.shape[2])
+    else:
+        weight = torch.bitwise_right_shift(
+            torch.unsqueeze(module.qweight, 1).expand(-1, 32 // bits, -1),
+            wf.unsqueeze(-1)).to(torch.int8)
+        weight = torch.bitwise_and(weight, (2 ** bits) - 1)
+        weight = weight.reshape(weight.shape[0] * weight.shape[1], weight.shape[2])
+        if act_order:
+            invalidInputError(module.g_idx.shape[0] == weight.shape[0],
+                              "g_idx and weight shape mismatch")
+            _, g_id_map = torch.sort(module.g_idx)
+            weight = weight[g_id_map, :]
 
     # convert weight to ggml format
     weight = weight.reshape(weight.shape[0]//group_size, group_size, weight.shape[1])
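The two branches above exist because AWQ and GPTQ pack 4-bit weights into int32 words differently: AWQ packs eight output columns per int32 in the interleaved nibble order [0, 4, 1, 5, 2, 6, 3, 7], while GPTQ packs eight input rows per int32 in natural order [0..7] and stores its zero points offset by one (hence the `zeros = zeros + 1` step). The stand-alone sketch below mirrors the shift-and-mask unpacking used in the patch and round-trips both layouts; the `pack_awq`/`pack_gptq` helpers are illustrative stand-ins for what AutoAWQ/AutoGPTQ checkpoints contain, not library code.

    # Round-trip check for the two 4-bit packing layouts undone above.
    # pack_* are illustrative stand-ins; unpack_* mirror the patch's bit logic.
    import torch

    bits = 4
    pack_factor = 32 // bits  # 8 nibbles per int32


    def pack_gptq(w):
        # GPTQ: eight consecutive *rows* share one int32; nibble i holds row i (LSB first).
        packed = torch.zeros(w.shape[0] // pack_factor, w.shape[1], dtype=torch.int32)
        for i in range(pack_factor):
            packed |= w[i::pack_factor].to(torch.int32) << (bits * i)
        return packed


    def pack_awq(w):
        # AWQ: eight consecutive *columns* share one int32; column j sits in nibble order[j].
        order = [0, 4, 1, 5, 2, 6, 3, 7]
        packed = torch.zeros(w.shape[0], w.shape[1] // pack_factor, dtype=torch.int32)
        for j, nibble in enumerate(order):
            packed |= w[:, j::pack_factor].to(torch.int32) << (bits * nibble)
        return packed


    def unpack_gptq(packed):
        wf = (torch.arange(pack_factor, dtype=torch.int32) * bits).unsqueeze(0)
        w = torch.bitwise_right_shift(
            torch.unsqueeze(packed, 1).expand(-1, pack_factor, -1),
            wf.unsqueeze(-1)).to(torch.int8)
        w = torch.bitwise_and(w, (2 ** bits) - 1)
        return w.reshape(w.shape[0] * w.shape[1], w.shape[2])


    def unpack_awq(packed):
        wf = (torch.tensor([0, 4, 1, 5, 2, 6, 3, 7], dtype=torch.int32) * bits).unsqueeze(0)
        w = torch.bitwise_right_shift(
            torch.unsqueeze(packed, 2).expand(-1, -1, pack_factor),
            wf.unsqueeze(0)).to(torch.int8)
        w = torch.bitwise_and(w, (2 ** bits) - 1)
        return w.reshape(w.shape[0], w.shape[1] * w.shape[2])


    ref = torch.randint(0, 2 ** bits, (16, 16))
    assert torch.equal(unpack_gptq(pack_gptq(ref)).to(ref.dtype), ref)
    assert torch.equal(unpack_awq(pack_awq(ref)).to(ref.dtype), ref)

For GPTQ checkpoints quantized with act_order, the patch additionally sorts `module.g_idx` and permutes the unpacked rows with the resulting permutation, so that each contiguous block of `group_size` rows shares one scale/zero before the ggml reshape.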
@@ -455,7 +482,7 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
         FP16Linear, BF16Linear
     from ipex_llm.transformers.embedding import CPUEmbedding, DiskEmbedding, LowBitEmbedding
     has_been_replaced = False
-    global _USE_VLLM_AWQ
+    global _USE_VLLM_AWQ, _USE_VLLM_GPTQ
 
     for name, module in model.named_children():
         is_linear, linear_args = is_linear_module(module)
@@ -523,7 +550,7 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                     if has_bias:
                         new_linear._parameters['bias'] = nn.Parameter(module.bias.data)\
                             .to(device)
-                elif _USE_VLLM_AWQ:
+                elif _USE_VLLM_AWQ or _USE_VLLM_GPTQ:
                     # User load an AWQ quantized model from vLLM
                     from ipex_llm.transformers.low_bit_linear import vLLMLowBitLinear
                     from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
@@ -571,7 +598,8 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                     device = module.qweight.data.device
                     invalidInputError(device.type != "meta",
                                       "converting from meta device is not supported")
-                    weight, g_idx_map = convert_vllm_awq(module)
+                    weight, g_idx_map = convert_vllm_awq_or_gptq(module, gptq=_USE_VLLM_GPTQ,
+                                                                 act_order=act_order)
                     if act_order:
                         new_linear.g_idx_map = g_idx_map
                     # Copy the weights
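With these changes, a 4-bit GPTQ checkpoint can be served through ipex-llm's vLLM integration the same way AWQ checkpoints already are. The sketch below is a hedged usage example, not part of this patch: the `IPEXLLMClass` import path and the `load_in_low_bit`/`device` arguments are taken from ipex-llm's existing vLLM serving examples and may differ between releases, and the model path is a placeholder; `quantization="gptq"` and `SamplingParams` are standard vLLM API.

    # Hypothetical serving sketch (assumptions noted above).
    from vllm import SamplingParams
    from ipex_llm.vllm.xpu.engine import IPEXLLMClass as LLM

    llm = LLM(model="/llm/models/Llama-2-7B-Chat-GPTQ",  # placeholder path to a 4-bit GPTQ checkpoint
              quantization="gptq",          # tell vLLM the checkpoint is GPTQ-quantized
              load_in_low_bit="asym_int4",  # ipex-llm low-bit format used after conversion
              device="xpu",
              dtype="float16")

    outputs = llm.generate(["What is AI?"],
                           SamplingParams(temperature=0.8, max_tokens=64))
    for out in outputs:
        print(out.outputs[0].text)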