Enable vllm load gptq model (#12083)
* enable vllm load gptq model
* update
* update
* update
* update style
This commit is contained in:
parent c2774e1a43
commit 40e463c66b

1 changed file with 41 additions and 13 deletions
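For context, the new _USE_VLLM_GPTQ path is exercised when a GPTQ-quantized checkpoint is loaded through vLLM and IPEX-LLM then converts the resulting quantized linear layers. A rough sketch of such a load using stock vLLM's offline API (the model id is only a placeholder, and the exact IPEX-LLM serving entry point is not part of this commit):

    from vllm import LLM, SamplingParams

    # Placeholder GPTQ checkpoint; any 4-bit GPTQ model goes through the same path.
    llm = LLM(model="TheBloke/Llama-2-7B-Chat-GPTQ", quantization="gptq")
    outputs = llm.generate(["What is AI?"], SamplingParams(max_tokens=32))
    print(outputs[0].outputs[0].text)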
@@ -56,6 +56,7 @@ import sys
 _IS_VLLM_AVAILABLE = None
 _USE_VLLM = False
 _USE_VLLM_AWQ = False
+_USE_VLLM_GPTQ = False
 _VLLM_VERSION = None
 
 
@@ -144,7 +145,7 @@ def is_linear_module(module):
     is_awq = is_auto_awq_available() and isinstance(module, WQLinear_GEMM)
     if is_vllm_available():
         # Only convert vllm modules
-        global _VLLM_VERSION, _USE_VLLM_AWQ
+        global _VLLM_VERSION, _USE_VLLM_AWQ, _USE_VLLM_GPTQ
         if _VLLM_VERSION is None:
             _VLLM_VERSION = get_package_version('vllm')
         from vllm.model_executor.layers.linear import (
@@ -186,6 +187,10 @@ def is_linear_module(module):
                 and hasattr(module.quant_method, "quant_config")
                 and module.quant_method.quant_config.get_name() == "awq"):
             _USE_VLLM_AWQ = True
+        if (not _USE_VLLM_GPTQ
+                and hasattr(module.quant_method, "quant_config")
+                and module.quant_method.quant_config.get_name() == "gptq"):
+            _USE_VLLM_GPTQ = True
         invalidInputError(module.skip_bias_add is not True, "Currently, ipex-vllm does not"
                           " support linear layers with skip_bias_add argument")
         if isinstance(module, RowParallelLinear) and tp_size >= 2:
@@ -291,16 +296,23 @@ def convert_vllm(module, qtype, in_features, out_features, mp_group, cur_qtype,
     return new_linear
 
 
-def convert_vllm_awq(module):
+def convert_vllm_awq_or_gptq(module, gptq=False, act_order=False):
     from ipex_llm.transformers.low_bit_linear import get_block_size
     Q4_1 = get_block_size("asym_int4")
 
     scales = module.scales
-    # vLLM only supports load 4-bits model, so this has been checked
-    bits = 4
+    # vLLM only supports load 4-bits model, so this has been checked
+    if gptq:
+        bits = module.quant_method.quant_config.weight_bits
+        wf = (torch.tensor([0, 1, 2, 3, 4, 5, 6, 7],
+                           dtype=torch.int32) * 4).unsqueeze(0)
+    else:
+        bits = 4
+        wf = (torch.tensor([0, 4, 1, 5, 2, 6, 3, 7],
+                           dtype=torch.int32) * 4).unsqueeze(0)
     group_size = module.quant_method.quant_config.group_size
     if int(group_size) % Q4_1 != 0:
         invalidInputError(False, (f"group_size:{group_size} must be divisible by "f"{Q4_1}."))
 
     zeros = torch.bitwise_right_shift(
         torch.unsqueeze(module.qzeros, 2).expand(-1, -1, 32 // bits),
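The two wf shift tables above encode how each format packs eight 4-bit weights into one int32: GPTQ stores them in sequential nibble order, while AWQ uses the interleaved order [0, 4, 1, 5, 2, 6, 3, 7]. A self-contained sketch (illustration only, not part of the commit) that mirrors the bitwise_right_shift / bitwise_and unpacking pattern used here:

    import torch

    vals = torch.arange(8, dtype=torch.int32)   # eight 4-bit values: 0..7

    # GPTQ: logical weight i lives in nibble i of the packed int32.
    packed_gptq = torch.tensor([sum(int(v) << (4 * i) for i, v in enumerate(vals))],
                               dtype=torch.int32)
    wf_gptq = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7], dtype=torch.int32) * 4
    out = torch.bitwise_right_shift(packed_gptq.unsqueeze(-1), wf_gptq).to(torch.int8)
    out = torch.bitwise_and(out, (2 ** 4) - 1)
    assert (out.flatten() == vals).all()

    # AWQ: logical weight n lives in nibble order[n], hence the permuted shift table.
    order = [0, 4, 1, 5, 2, 6, 3, 7]
    packed_awq = torch.tensor([sum(int(vals[n]) << (4 * order[n]) for n in range(8))],
                              dtype=torch.int32)
    wf_awq = torch.tensor(order, dtype=torch.int32) * 4
    out = torch.bitwise_right_shift(packed_awq.unsqueeze(-1), wf_awq).to(torch.int8)
    out = torch.bitwise_and(out, (2 ** 4) - 1)
    assert (out.flatten() == vals).all()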
@@ -309,13 +321,28 @@ def convert_vllm_awq(module):
 
     g_id_map = None
 
+    if gptq:
+        zeros = zeros + 1
     zeros = zeros.reshape(scales.shape)
 
-    weight = torch.bitwise_right_shift(
-        torch.unsqueeze(module.qweight, 2).expand(-1, -1, 32 // bits),
-        wf.unsqueeze(0)).to(torch.int16 if bits == 8 else torch.int8)
-    weight = torch.bitwise_and(weight, (2 ** bits) - 1)
-    weight = weight.reshape(weight.shape[0], weight.shape[1] * weight.shape[2])
+    if not gptq:
+        weight = torch.bitwise_right_shift(
+            torch.unsqueeze(module.qweight, 2).expand(-1, -1, 32 // bits),
+            wf.unsqueeze(0)).to(torch.int16 if bits == 8 else torch.int8)
+        weight = torch.bitwise_and(weight, (2 ** bits) - 1)
+        weight = weight.reshape(weight.shape[0], weight.shape[1] * weight.shape[2])
+    else:
+        weight = torch.bitwise_right_shift(
+            torch.unsqueeze(module.qweight, 1).expand(-1, 32 // bits, -1),
+            wf.unsqueeze(-1)).to(torch.int8)
+        weight = torch.bitwise_and(weight, (2 ** bits) - 1)
+        weight = weight.reshape(weight.shape[0] * weight.shape[1], weight.shape[2])
+
+    if act_order:
+        invalidInputError(module.g_idx.shape[0] == weight.shape[0],
+                          "g_idx and weight shape mismatch")
+        _, g_id_map = torch.sort(module.g_idx)
+        weight = weight[g_id_map, :]
 
     # convert weight to ggml format
     weight = weight.reshape(weight.shape[0]//group_size, group_size, weight.shape[1])
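For act_order (desc_act) GPTQ checkpoints, g_idx[i] gives the quantization group of input channel i and is not monotonically increasing, so the rows are permuted into group order before the per-group ggml repacking; the permutation is also returned so the converted layer can keep it (see new_linear.g_idx_map below). A small sketch of that reordering step, with made-up g_idx values:

    import torch

    group_size = 4
    g_idx = torch.tensor([1, 0, 1, 0, 0, 1, 0, 1])   # hypothetical group ids per input channel
    weight = torch.arange(8 * 2).reshape(8, 2)       # 8 input channels, 2 output columns

    # Same reordering as in the diff: sort g_idx and permute rows so that each
    # group's channels are contiguous before the (num_groups, group_size, ...) reshape.
    _, g_id_map = torch.sort(g_idx)
    weight = weight[g_id_map, :]
    grouped = weight.reshape(weight.shape[0] // group_size, group_size, weight.shape[1])
    print(g_id_map)        # group-0 channels first, then group-1 channels
    print(grouped.shape)   # torch.Size([2, 4, 2])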
@@ -455,7 +482,7 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
         FP16Linear, BF16Linear
     from ipex_llm.transformers.embedding import CPUEmbedding, DiskEmbedding, LowBitEmbedding
     has_been_replaced = False
-    global _USE_VLLM_AWQ
+    global _USE_VLLM_AWQ, _USE_VLLM_GPTQ
 
     for name, module in model.named_children():
         is_linear, linear_args = is_linear_module(module)
@@ -523,7 +550,7 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                 if has_bias:
                     new_linear._parameters['bias'] = nn.Parameter(module.bias.data)\
                         .to(device)
-            elif _USE_VLLM_AWQ:
+            elif _USE_VLLM_AWQ or _USE_VLLM_GPTQ:
                 # User load an AWQ quantized model from vLLM
                 from ipex_llm.transformers.low_bit_linear import vLLMLowBitLinear
                 from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
@@ -571,7 +598,8 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                 device = module.qweight.data.device
                 invalidInputError(device.type != "meta",
                                   "converting from meta device is not supported")
-                weight, g_idx_map = convert_vllm_awq(module)
+                weight, g_idx_map = convert_vllm_awq_or_gptq(module, gptq=_USE_VLLM_GPTQ,
+                                                             act_order=act_order)
                 if act_order:
                     new_linear.g_idx_map = g_idx_map
                 # Copy the weights