Add vllm awq loading logic (#11950)
* add vllm awq loading logic
* fix
* refine
This commit is contained in:
parent b38fb67bec
commit 0a7bd274e2
1 changed file with 131 additions and 1 deletion
@@ -55,6 +55,7 @@ import sys
 _IS_VLLM_AVAILABLE = None
 _USE_VLLM = False
+_USE_VLLM_AWQ = False
 _VLLM_VERSION = None
@@ -143,7 +144,7 @@ def is_linear_module(module)
     is_awq = is_auto_awq_available() and isinstance(module, WQLinear_GEMM)
     if is_vllm_available():
         # Only convert vllm modules
-        global _VLLM_VERSION
+        global _VLLM_VERSION, _USE_VLLM_AWQ
         if _VLLM_VERSION is None:
             _VLLM_VERSION = get_package_version('vllm')
         from vllm.model_executor.layers.linear import (
@@ -180,6 +181,11 @@ def is_linear_module(module)
             out_features = module.output_size
             result = True
             mp_group = None
+            # Check for attribute qweight
+            if (not _USE_VLLM_AWQ
+                    and hasattr(module.linear_method, "quant_config")
+                    and module.linear_method.quant_config.get_name() == "awq"):
+                _USE_VLLM_AWQ = True
             invalidInputError(module.skip_bias_add is not True, "Currently, ipex-vllm does not"
                               " support linear layers with skip_bias_add argument")
             if isinstance(module, RowParallelLinear) and tp_size >= 2:
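Aside: the check added above keys off the quantization config that vLLM attaches to its linear layers. A minimal standalone sketch of the same test follows; the helper and the fake classes are hypothetical, and only the linear_method.quant_config.get_name() == "awq" pattern comes from the diff:

def _is_vllm_awq_linear(module) -> bool:
    # Hypothetical helper mirroring the check added above: vLLM linear layers
    # carry a linear_method whose quant_config reports "awq" for AWQ checkpoints.
    quant_config = getattr(getattr(module, "linear_method", None), "quant_config", None)
    return quant_config is not None and quant_config.get_name() == "awq"

class _FakeAWQConfig:
    @staticmethod
    def get_name():
        return "awq"

class _FakeLinearMethod:
    quant_config = _FakeAWQConfig()

class _FakeVLLMLinear:
    linear_method = _FakeLinearMethod()

print(_is_vllm_awq_linear(_FakeVLLMLinear()))   # True
print(_is_vllm_awq_linear(object()))            # False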
@@ -286,6 +292,65 @@ def convert_vllm(module, qtype, in_features, out_features, mp_group, cur_qtype,
     return new_linear


+def convert_vllm_awq(module):
+    from ipex_llm.transformers.low_bit_linear import get_block_size
+    Q4_1 = get_block_size("asym_int4")
+
+    scales = module.scales
+    wf = (torch.tensor([0, 4, 1, 5, 2, 6, 3, 7],
+                       dtype=torch.int32) * 4).unsqueeze(0)
+    # vLLM only supports loading 4-bit models, so this has already been checked
+    bits = 4
+    group_size = module.linear_method.quant_config.group_size
+
+    zeros = torch.bitwise_right_shift(
+        torch.unsqueeze(module.qzeros, 2).expand(-1, -1, 32 // bits),
+        wf.unsqueeze(0)).to(torch.int16 if bits == 8 else torch.int8)
+    zeros = torch.bitwise_and(zeros, (2 ** bits) - 1)
+
+    g_id_map = None
+
+    zeros = zeros.reshape(scales.shape)
+
+    weight = torch.bitwise_right_shift(
+        torch.unsqueeze(module.qweight, 2).expand(-1, -1, 32 // bits),
+        wf.unsqueeze(0)).to(torch.int16 if bits == 8 else torch.int8)
+    weight = torch.bitwise_and(weight, (2 ** bits) - 1)
+    weight = weight.reshape(weight.shape[0], weight.shape[1] * weight.shape[2])
+
+    # convert weight to ggml format
+    weight = weight.reshape(weight.shape[0]//group_size, group_size, weight.shape[1])
+    weight = weight.permute(2, 0, 1).reshape(weight.shape[2], -1, 2, Q4_1//2)
+    weight = weight.transpose(2, 3)
+    weight = torch.bitwise_left_shift(weight,
+                                      torch.tensor([0, 4], dtype=torch.int8).reshape(1, 1, 1, 2))
+    weight = torch.bitwise_or(weight[:, :, :, 0], weight[:, :, :, 1]).contiguous()
+
+    # convert zeros to ggml format
+    zeros = zeros.reshape(-1, 1, zeros.shape[1]).permute(2, 0, 1)\
+        .unsqueeze(2)\
+        .expand(-1, -1, group_size//Q4_1, -1)\
+        .reshape(zeros.shape[1], -1, 1)\
+        .contiguous().to(torch.float16)
+
+    # convert scales to ggml format
+    scales = scales.reshape(-1, 1, scales.shape[1]).permute(2, 0, 1)\
+        .unsqueeze(2)\
+        .expand(-1, -1, group_size//Q4_1, -1)\
+        .reshape(scales.shape[-1], -1, 1)\
+        .contiguous().to(torch.float16)
+
+    m = -(zeros * scales)
+    d = scales
+
+    ggml_weight = torch.cat([d.view(torch.uint8),
+                             m.view(torch.uint8),
+                             weight.view(torch.uint8)], dim=-1)
+    ggml_weight = ggml_weight.reshape([-1])
+
+    return ggml_weight, g_id_map
+
+
 def convert_gptq(module, awq=False, llm_awq=False, act_order=False):
     from ipex_llm.transformers.low_bit_linear import get_block_size
     Q4_1 = get_block_size("asym_int4")
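Aside: in the AWQ layout each int32 in qweight and qzeros packs eight 4-bit values, and the shift table [0, 4, 1, 5, 2, 6, 3, 7] * 4 undoes AWQ's interleaved packing order while unpacking. A small self-contained sketch of that unpacking step, using a dummy packed tensor and the same tensor ops as convert_vllm_awq above:

import torch

bits = 4
wf = (torch.tensor([0, 4, 1, 5, 2, 6, 3, 7], dtype=torch.int32) * 4).unsqueeze(0)

# Dummy packed weights: 16 int32 columns hold 16 * 8 = 128 four-bit values per row.
qweight = torch.randint(-2**31, 2**31 - 1, (128, 16), dtype=torch.int32)

# Expand each int32 into eight shifted copies, then keep only the low nibble.
unpacked = torch.bitwise_right_shift(
    torch.unsqueeze(qweight, 2).expand(-1, -1, 32 // bits),
    wf.unsqueeze(0)).to(torch.int8)
unpacked = torch.bitwise_and(unpacked, (2 ** bits) - 1)
unpacked = unpacked.reshape(qweight.shape[0], -1)
print(unpacked.shape)   # torch.Size([128, 128]), values in [0, 15]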
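Aside: the tail of convert_vllm_awq repacks everything into ggml's asym_int4 (Q4_1-style) blocks, which store a scale d, a minimum m and the 4-bit quants, and dequantize as w = d * q + m. AWQ dequantizes the same weight as w = scales * (q - zeros), so setting d = scales and m = -(zeros * scales), exactly as the function does, keeps the two conventions numerically identical. A tiny check of that identity with illustrative values:

import torch

# One quantization group of eight 4-bit values with its AWQ scale and zero point.
q = torch.tensor([3, 0, 15, 7, 7, 1, 12, 9], dtype=torch.float32)
scales = torch.tensor(0.02)
zeros = torch.tensor(8.0)

awq_dequant = scales * (q - zeros)     # AWQ convention

d = scales                             # asym_int4 scale
m = -(zeros * scales)                  # asym_int4 minimum
ggml_dequant = d * q + m               # ggml asym_int4 convention

assert torch.allclose(awq_dequant, ggml_dequant)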
@@ -391,6 +456,7 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
         FP16Linear, BF16Linear
     from ipex_llm.transformers.embedding import CPUEmbedding, DiskEmbedding, LowBitEmbedding
     has_been_replaced = False
+    global _USE_VLLM_AWQ

     for name, module in model.named_children():
         is_linear, linear_args = is_linear_module(module)
@@ -452,6 +518,70 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                     if has_bias:
                         new_linear._parameters['bias'] = nn.Parameter(module.bias.data)\
                             .to(device)
+                elif _USE_VLLM_AWQ:
+                    # The user loaded an AWQ-quantized model from vLLM
+                    from ipex_llm.transformers.low_bit_linear import vLLMLowBitLinear
+                    from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
+                    has_bias = module.bias is not None and module.bias.abs().sum() != 0
+                    if isinstance(module, ParallelLMHead):
+                        new_linear = LowBitLinear(
+                            in_features,
+                            out_features,
+                            qtype=qtype,
+                            bias=has_bias,
+                            mp_group=mp_group,
+                            enable_xetla=enable_xetla,
+                            optimize_lm_head=False,
+                            act_order=act_order,
+                            enable_scale_search=enable_scale_search,
+                        )
+                        device = module.weight.data.device
+                        cur_qtype, cur_imatrix = get_cur_qtype_and_imatrix(qtype,
+                                                                           full_module_name,
+                                                                           imatrix_data,
+                                                                           model_config)
+                        # Copy the weights
+                        paramsLowBit = FP4Params(data=module.weight.data,
+                                                 requires_grad=False,
+                                                 quantized=False,
+                                                 _shape=None,
+                                                 convert_shape_only=convert_shape_only,
+                                                 qtype=cur_qtype,
+                                                 imatrix=cur_imatrix,
+                                                 in_features=in_features,
+                                                 enable_xetla=enable_xetla,
+                                                 enable_scale_search=enable_scale_search).to(device)
+                    else:
+                        new_linear = vLLMLowBitLinear(
+                            in_features,
+                            out_features,
+                            qtype=qtype,
+                            bias=has_bias,
+                            mp_group=mp_group,
+                            enable_xetla=enable_xetla,
+                            optimize_lm_head=False,
+                            act_order=act_order,
+                            enable_scale_search=enable_scale_search,
+                        )
+                        device = module.qweight.data.device
+                        invalidInputError(device.type != "meta",
+                                          "converting from meta device is not supported")
+                        weight, g_idx_map = convert_vllm_awq(module)
+                        if act_order:
+                            new_linear.g_idx_map = g_idx_map
+                        # Copy the weights
+                        paramsLowBit = FP4Params(data=weight,
+                                                 requires_grad=False,
+                                                 quantized=True,
+                                                 _shape=(out_features, in_features),
+                                                 convert_shape_only=convert_shape_only,
+                                                 qtype=qtype,
+                                                 enable_xetla=enable_xetla,
+                                                 enable_scale_search=enable_scale_search).to(device)
+                    new_linear._parameters['weight'] = paramsLowBit
+                    if has_bias:
+                        new_linear._parameters['bias'] = nn.Parameter(module.bias.data)\
+                            .to(device)
                 elif qtype not in [ggml_tensor_qtype["fp16"], ggml_tensor_qtype["bf16"]]:
                     if in_features % 64 != 0:
                         # now our kernel requires in_features is a multiple of 64
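Aside: two details of this branch are easy to miss. The LM head (vLLM's ParallelLMHead) still holds a floating-point weight, so it is quantized through the usual FP4Params(quantized=False) path, while every other AWQ layer reuses its packed weights via convert_vllm_awq and FP4Params(quantized=True) inside a vLLMLowBitLinear. Also, a bias tensor that exists but is all zeros is treated as no bias. A tiny runnable illustration of that bias check, with a plain nn.Linear standing in for a vLLM layer:

import torch
import torch.nn as nn

layer = nn.Linear(8, 8, bias=True)
nn.init.zeros_(layer.bias)

# Same test as in the diff: a present-but-all-zero bias counts as "no bias".
has_bias = layer.bias is not None and layer.bias.abs().sum() != 0
print(bool(has_bias))   # False, so the zero bias is dropped during conversion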