Add vllm awq loading logic (#11950)
* add vllm awq loading logic
* fix
* refine
parent b38fb67bec
commit 0a7bd274e2
1 changed file with 131 additions and 1 deletion
@@ -55,6 +55,7 @@ import sys
 _IS_VLLM_AVAILABLE = None
 _USE_VLLM = False
+_USE_VLLM_AWQ = False
 _VLLM_VERSION = None
@@ -143,7 +144,7 @@ def is_linear_module(module):
     is_awq = is_auto_awq_available() and isinstance(module, WQLinear_GEMM)
     if is_vllm_available():
         # Only convert vllm modules
-        global _VLLM_VERSION
+        global _VLLM_VERSION, _USE_VLLM_AWQ
         if _VLLM_VERSION is None:
             _VLLM_VERSION = get_package_version('vllm')
         from vllm.model_executor.layers.linear import (
@@ -180,6 +181,11 @@ def is_linear_module(module):
             out_features = module.output_size
             result = True
             mp_group = None
+            # Check for attribute qweight
+            if (not _USE_VLLM_AWQ
+                    and hasattr(module.linear_method, "quant_config")
+                    and module.linear_method.quant_config.get_name() == "awq"):
+                _USE_VLLM_AWQ = True
             invalidInputError(module.skip_bias_add is not True, "Currently, ipex-vllm does not"
                               " support linear layers with skip_bias_add argument")
             if isinstance(module, RowParallelLinear) and tp_size >= 2:
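Read as a predicate, the check added above says: a vLLM linear layer is treated as AWQ-quantized when its linear_method carries a quant_config whose name is "awq"; unquantized layers expose no quant_config, which is why the hasattr() guard comes first. Once one such layer is seen, the module-level _USE_VLLM_AWQ flag stays set for the rest of the conversion. A minimal illustrative sketch (the helper name is made up, not part of the commit):

def looks_like_vllm_awq_layer(module) -> bool:
    # Unquantized vLLM layers carry no quant_config, hence the attribute check first.
    return (hasattr(module.linear_method, "quant_config")
            and module.linear_method.quant_config.get_name() == "awq")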
@@ -286,6 +292,65 @@ def convert_vllm(module, qtype, in_features, out_features, mp_group, cur_qtype,
     return new_linear
 
 
+def convert_vllm_awq(module):
+    from ipex_llm.transformers.low_bit_linear import get_block_size
+    Q4_1 = get_block_size("asym_int4")
+
+    scales = module.scales
+    wf = (torch.tensor([0, 4, 1, 5, 2, 6, 3, 7],
+                       dtype=torch.int32) * 4).unsqueeze(0)
+    # vLLM only supports load 4-bits model, so this has been checked
+    bits = 4
+    group_size = module.linear_method.quant_config.group_size
+
+    zeros = torch.bitwise_right_shift(
+        torch.unsqueeze(module.qzeros, 2).expand(-1, -1, 32 // bits),
+        wf.unsqueeze(0)).to(torch.int16 if bits == 8 else torch.int8)
+    zeros = torch.bitwise_and(zeros, (2 ** bits) - 1)
+
+    g_id_map = None
+
+    zeros = zeros.reshape(scales.shape)
+
+    weight = torch.bitwise_right_shift(
+        torch.unsqueeze(module.qweight, 2).expand(-1, -1, 32 // bits),
+        wf.unsqueeze(0)).to(torch.int16 if bits == 8 else torch.int8)
+    weight = torch.bitwise_and(weight, (2 ** bits) - 1)
+    weight = weight.reshape(weight.shape[0], weight.shape[1] * weight.shape[2])
+
+    # convert weight to ggml format
+    weight = weight.reshape(weight.shape[0]//group_size, group_size, weight.shape[1])
+    weight = weight.permute(2, 0, 1).reshape(weight.shape[2], -1, 2, Q4_1//2)
+    weight = weight.transpose(2, 3)
+    weight = torch.bitwise_left_shift(weight,
+                                      torch.tensor([0, 4], dtype=torch.int8).reshape(1, 1, 1, 2))
+    weight = torch.bitwise_or(weight[:, :, :, 0], weight[:, :, :, 1]).contiguous()
+
+    # convert zeros to ggml format
+    zeros = zeros.reshape(-1, 1, zeros.shape[1]).permute(2, 0, 1)\
+        .unsqueeze(2)\
+        .expand(-1, -1, group_size//Q4_1, -1)\
+        .reshape(zeros.shape[1], -1, 1)\
+        .contiguous().to(torch.float16)
+
+    # convert scales to ggml format
+    scales = scales.reshape(-1, 1, scales.shape[1]).permute(2, 0, 1)\
+        .unsqueeze(2)\
+        .expand(-1, -1, group_size//Q4_1, -1)\
+        .reshape(scales.shape[-1], -1, 1)\
+        .contiguous().to(torch.float16)
+
+    m = -(zeros * scales)
+    d = scales
+
+    ggml_weight = torch.cat([d.view(torch.uint8),
+                             m.view(torch.uint8),
+                             weight.view(torch.uint8)], dim=-1)
+    ggml_weight = ggml_weight.reshape([-1])
+
+    return ggml_weight, g_id_map
+
+
 def convert_gptq(module, awq=False, llm_awq=False, act_order=False):
     from ipex_llm.transformers.low_bit_linear import get_block_size
     Q4_1 = get_block_size("asym_int4")
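Two details in convert_vllm_awq above are easy to miss. The wf = [0, 4, 1, 5, 2, 6, 3, 7] * 4 shift table undoes AWQ's interleaved packing, in which the eight 4-bit values of each int32 are stored in nibble order [0, 2, 4, 6, 1, 3, 5, 7]. And setting d = scales, m = -(zeros * scales) maps AWQ's dequantization (q - z) * s onto ggml's asym_int4 (Q4_1) form d * q + m, whose 32-element blocks hold an fp16 d, an fp16 m and 16 packed bytes, which is what the final torch.cat assembles. The standalone sketch below illustrates both points; the example values and helper variables are made up for illustration and are not part of the commit.

import torch

bits = 4
# Logical 4-bit values for one packed int32 (made-up example data).
vals = torch.tensor([12, 3, 14, 5, 9, 0, 11, 7], dtype=torch.int32)

# Pack them the way AWQ does: nibble i of the int32 holds logical element order[i].
order = [0, 2, 4, 6, 1, 3, 5, 7]
packed = torch.zeros(1, dtype=torch.int32)
for i, j in enumerate(order):
    packed |= vals[j] << (bits * i)

# Unpack with the same shift table convert_vllm_awq uses; the values come back
# in logical order.
wf = torch.tensor([0, 4, 1, 5, 2, 6, 3, 7], dtype=torch.int32) * bits
unpacked = torch.bitwise_and(torch.bitwise_right_shift(packed, wf), (1 << bits) - 1)
assert torch.equal(unpacked, vals)

# AWQ dequantization (q - z) * s equals ggml Q4_1 dequantization d * q + m
# when d = s and m = -z * s, which is how the function fills each block.
q = unpacked.float()
s, z = torch.tensor(0.05), torch.tensor(8.0)
d, m = s, -(z * s)
assert torch.allclose((q - z) * s, d * q + m)

The expand(-1, -1, group_size//Q4_1, -1) steps in the real code simply repeat each AWQ group's scale and zero point so that every 32-element Q4_1 block carries its own d and m.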
@@ -391,6 +456,7 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
         FP16Linear, BF16Linear
     from ipex_llm.transformers.embedding import CPUEmbedding, DiskEmbedding, LowBitEmbedding
     has_been_replaced = False
+    global _USE_VLLM_AWQ
 
     for name, module in model.named_children():
         is_linear, linear_args = is_linear_module(module)
@@ -452,6 +518,70 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                    if has_bias:
                        new_linear._parameters['bias'] = nn.Parameter(module.bias.data)\
                            .to(device)
+                elif _USE_VLLM_AWQ:
+                    # User load an AWQ quantized model from vLLM
+                    from ipex_llm.transformers.low_bit_linear import vLLMLowBitLinear
+                    from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
+                    has_bias = module.bias is not None and module.bias.abs().sum() != 0
+                    if isinstance(module, ParallelLMHead):
+                        new_linear = LowBitLinear(
+                            in_features,
+                            out_features,
+                            qtype=qtype,
+                            bias=has_bias,
+                            mp_group=mp_group,
+                            enable_xetla=enable_xetla,
+                            optimize_lm_head=False,
+                            act_order=act_order,
+                            enable_scale_search=enable_scale_search,
+                        )
+                        device = module.weight.data.device
+                        cur_qtype, cur_imatrix = get_cur_qtype_and_imatrix(qtype,
+                                                                           full_module_name,
+                                                                           imatrix_data,
+                                                                           model_config)
+                        # Copy the weights
+                        paramsLowBit = FP4Params(data=module.weight.data,
+                                                 requires_grad=False,
+                                                 quantized=False,
+                                                 _shape=None,
+                                                 convert_shape_only=convert_shape_only,
+                                                 qtype=cur_qtype,
+                                                 imatrix=cur_imatrix,
+                                                 in_features=in_features,
+                                                 enable_xetla=enable_xetla,
+                                                 enable_scale_search=enable_scale_search).to(device)
+                    else:
+                        new_linear = vLLMLowBitLinear(
+                            in_features,
+                            out_features,
+                            qtype=qtype,
+                            bias=has_bias,
+                            mp_group=mp_group,
+                            enable_xetla=enable_xetla,
+                            optimize_lm_head=False,
+                            act_order=act_order,
+                            enable_scale_search=enable_scale_search,
+                        )
+                        device = module.qweight.data.device
+                        invalidInputError(device.type != "meta",
+                                          "converting from meta device is not supported")
+                        weight, g_idx_map = convert_vllm_awq(module)
+                        if act_order:
+                            new_linear.g_idx_map = g_idx_map
+                        # Copy the weights
+                        paramsLowBit = FP4Params(data=weight,
+                                                 requires_grad=False,
+                                                 quantized=True,
+                                                 _shape=(out_features, in_features),
+                                                 convert_shape_only=convert_shape_only,
+                                                 qtype=qtype,
+                                                 enable_xetla=enable_xetla,
+                                                 enable_scale_search=enable_scale_search).to(device)
+                    new_linear._parameters['weight'] = paramsLowBit
+                    if has_bias:
+                        new_linear._parameters['bias'] = nn.Parameter(module.bias.data)\
+                            .to(device)
                 elif qtype not in [ggml_tensor_qtype["fp16"], ggml_tensor_qtype["bf16"]]:
                     if in_features % 64 != 0:
                         # now our kernel requires in_features is a multiple of 64
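One small point in the branch above: has_bias is computed as module.bias is not None and module.bias.abs().sum() != 0 rather than a plain None check, so a bias tensor that exists but is identically zero is treated as no bias at all and is never copied into the converted layer, presumably to handle checkpoints where a zero bias is materialized even though the layer has no effective bias. A tiny illustrative check of that rule (the helper name is made up, not from the commit):

import torch

def has_real_bias(bias):
    # Mirrors the rule above: a bias only counts if it is present and not identically zero.
    return bias is not None and bool(bias.abs().sum() != 0)

assert has_real_bias(torch.tensor([0.0, 0.1]))
assert not has_real_bias(torch.zeros(4))
assert not has_real_bias(None)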