Fix vLLM not convert issues (#11817)

* Fix not convert issues

* refine
This commit is contained in:
Guancheng Fu 2024-08-15 19:04:05 +08:00 committed by GitHub
parent 750d4ad5dc
commit e70ae0638e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 164 additions and 26 deletions

View file

@ -180,6 +180,8 @@ def is_linear_module(module):
out_features = module.output_size
result = True
mp_group = None
invalidInputError(module.skip_bias_add is not True, "Currently, ipex-vllm does not"
" support linear layers with skip_bias_add argument")
if isinstance(module, RowParallelLinear) and tp_size >= 2:
mp_group = get_tensor_model_parallel_group()
in_features = module.input_size_per_partition
@ -218,6 +220,70 @@ def is_linear_module(module):
return result, (in_features, out_features, mp_group)
def convert_vllm(module, qtype, in_features, out_features, mp_group, cur_qtype,
enable_xetla, optimize_lm_head, enable_scale_search):
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from ipex_llm.transformers.low_bit_linear import LowBitLinear, \
FP16Linear, BF16Linear, vLLMLowBitLinear, vLLMFP16Linear, vLLMBF16Linear
if isinstance(module, ParallelLMHead):
if qtype == ggml_tensor_qtype["fp16"]:
new_linear = FP16Linear(
in_features,
out_features,
module.bias is not None,
mp_group=mp_group,
optimize_lm_head=optimize_lm_head
)
elif qtype == ggml_tensor_qtype["bf16"]:
new_linear = BF16Linear(
in_features,
out_features,
module.bias is not None,
mp_group=mp_group,
optimize_lm_head=optimize_lm_head
)
else:
new_linear = LowBitLinear(
in_features,
out_features,
cur_qtype,
module.bias is not None,
mp_group=mp_group,
enable_xetla=enable_xetla,
optimize_lm_head=optimize_lm_head,
enable_scale_search=enable_scale_search,
)
else:
if qtype == ggml_tensor_qtype["fp16"]:
new_linear = vLLMFP16Linear(
in_features,
out_features,
module.bias is not None,
mp_group=mp_group,
optimize_lm_head=optimize_lm_head
)
elif qtype == ggml_tensor_qtype["bf16"]:
new_linear = vLLMBF16Linear(
in_features,
out_features,
module.bias is not None,
mp_group=mp_group,
optimize_lm_head=optimize_lm_head
)
else:
new_linear = vLLMLowBitLinear(
in_features,
out_features,
cur_qtype,
module.bias is not None,
mp_group=mp_group,
enable_xetla=enable_xetla,
optimize_lm_head=optimize_lm_head,
enable_scale_search=enable_scale_search,
)
return new_linear
def convert_gptq(module, awq=False, llm_awq=False, act_order=False):
from ipex_llm.transformers.low_bit_linear import get_block_size
Q4_1 = get_block_size("asym_int4")
@ -399,6 +465,17 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
# check hidden size whether is a multiple of 256
cur_qtype = check_hidden_size(cur_qtype, in_features)
if _USE_VLLM:
new_linear = convert_vllm(module,
qtype,
in_features,
out_features,
mp_group,
cur_qtype,
enable_xetla,
optimize_lm_head,
enable_scale_search)
else:
new_linear = LowBitLinear(
in_features,
out_features,
@ -427,6 +504,19 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
.to(device)
elif qtype == ggml_tensor_qtype["fp16"]:
module.to(torch.float16)
if _USE_VLLM:
new_linear = convert_vllm(
module,
qtype,
in_features,
out_features,
mp_group,
None,
None,
optimize_lm_head,
None
)
else:
new_linear = FP16Linear(
in_features,
out_features,
@ -449,6 +539,19 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
.to(device)
elif qtype == ggml_tensor_qtype["bf16"]:
module.to(torch.bfloat16)
if _USE_VLLM:
new_linear = convert_vllm(
module,
qtype,
in_features,
out_features,
mp_group,
None,
None,
optimize_lm_head,
None
)
else:
new_linear = BF16Linear(
in_features,
out_features,

View file

@ -1009,3 +1009,38 @@ class BF16Linear(nn.Linear):
result = result.reshape(*original_shape[:-1], result.shape[-1])
return result.to(x.dtype)
class vLLMLowBitLinear(LowBitLinear):
def __init__(self, input_features, output_features, qtype, bias=True,
conver_to_half=True, mp_group=None, enable_xetla=False,
optimize_lm_head=False, act_order=False,
enable_scale_search=False):
super().__init__(input_features, output_features, qtype, bias, conver_to_half, mp_group,
enable_xetla, optimize_lm_head, act_order, enable_scale_search)
def forward(self, x: torch.Tensor):
result = super().forward(x)
return result, None
class vLLMFP16Linear(FP16Linear):
def __init__(self, input_features, output_features, bias=True, mp_group=None, weight_type=1,
enable_xetla=False, optimize_lm_head=False):
super().__init__(input_features, output_features, bias, mp_group, weight_type,
enable_xetla, optimize_lm_head)
def forward(self, x: torch.Tensor):
result = super().forward(x)
return result, None
class vLLMBF16Linear(BF16Linear):
def __init__(self, input_features, output_features, bias=True, mp_group=None,
compute_dtype=None, enable_xetla=False, optimize_lm_head=False):
super().__init__(input_features, output_features, bias, mp_group, compute_dtype,
enable_xetla, optimize_lm_head)
def forward(self, x: torch.Tensor):
result = super().forward(x)
return result, None

View file

@ -225,8 +225,8 @@ def _ipex_llm_convert(load_in_low_bit):
def get_load_function(low_bit):
def _ipex_llm_load_model(self) -> None:
_model_mlp_convert()
_model_attention_convert()
# _model_mlp_convert()
# _model_attention_convert()
_model_sample_convert()
from vllm.utils import measure_device_memory