parent
750d4ad5dc
commit
e70ae0638e
3 changed files with 164 additions and 26 deletions
|
|
@ -180,6 +180,8 @@ def is_linear_module(module):
|
|||
out_features = module.output_size
|
||||
result = True
|
||||
mp_group = None
|
||||
invalidInputError(module.skip_bias_add is not True, "Currently, ipex-vllm does not"
|
||||
" support linear layers with skip_bias_add argument")
|
||||
if isinstance(module, RowParallelLinear) and tp_size >= 2:
|
||||
mp_group = get_tensor_model_parallel_group()
|
||||
in_features = module.input_size_per_partition
|
||||
|
|
@ -218,6 +220,70 @@ def is_linear_module(module):
|
|||
return result, (in_features, out_features, mp_group)
|
||||
|
||||
|
||||
def convert_vllm(module, qtype, in_features, out_features, mp_group, cur_qtype,
|
||||
enable_xetla, optimize_lm_head, enable_scale_search):
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
|
||||
from ipex_llm.transformers.low_bit_linear import LowBitLinear, \
|
||||
FP16Linear, BF16Linear, vLLMLowBitLinear, vLLMFP16Linear, vLLMBF16Linear
|
||||
if isinstance(module, ParallelLMHead):
|
||||
if qtype == ggml_tensor_qtype["fp16"]:
|
||||
new_linear = FP16Linear(
|
||||
in_features,
|
||||
out_features,
|
||||
module.bias is not None,
|
||||
mp_group=mp_group,
|
||||
optimize_lm_head=optimize_lm_head
|
||||
)
|
||||
elif qtype == ggml_tensor_qtype["bf16"]:
|
||||
new_linear = BF16Linear(
|
||||
in_features,
|
||||
out_features,
|
||||
module.bias is not None,
|
||||
mp_group=mp_group,
|
||||
optimize_lm_head=optimize_lm_head
|
||||
)
|
||||
else:
|
||||
new_linear = LowBitLinear(
|
||||
in_features,
|
||||
out_features,
|
||||
cur_qtype,
|
||||
module.bias is not None,
|
||||
mp_group=mp_group,
|
||||
enable_xetla=enable_xetla,
|
||||
optimize_lm_head=optimize_lm_head,
|
||||
enable_scale_search=enable_scale_search,
|
||||
)
|
||||
else:
|
||||
if qtype == ggml_tensor_qtype["fp16"]:
|
||||
new_linear = vLLMFP16Linear(
|
||||
in_features,
|
||||
out_features,
|
||||
module.bias is not None,
|
||||
mp_group=mp_group,
|
||||
optimize_lm_head=optimize_lm_head
|
||||
)
|
||||
elif qtype == ggml_tensor_qtype["bf16"]:
|
||||
new_linear = vLLMBF16Linear(
|
||||
in_features,
|
||||
out_features,
|
||||
module.bias is not None,
|
||||
mp_group=mp_group,
|
||||
optimize_lm_head=optimize_lm_head
|
||||
)
|
||||
else:
|
||||
new_linear = vLLMLowBitLinear(
|
||||
in_features,
|
||||
out_features,
|
||||
cur_qtype,
|
||||
module.bias is not None,
|
||||
mp_group=mp_group,
|
||||
enable_xetla=enable_xetla,
|
||||
optimize_lm_head=optimize_lm_head,
|
||||
enable_scale_search=enable_scale_search,
|
||||
)
|
||||
return new_linear
|
||||
|
||||
|
||||
def convert_gptq(module, awq=False, llm_awq=False, act_order=False):
|
||||
from ipex_llm.transformers.low_bit_linear import get_block_size
|
||||
Q4_1 = get_block_size("asym_int4")
|
||||
|
|
@ -399,6 +465,17 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
|
|||
# check hidden size whether is a multiple of 256
|
||||
cur_qtype = check_hidden_size(cur_qtype, in_features)
|
||||
|
||||
if _USE_VLLM:
|
||||
new_linear = convert_vllm(module,
|
||||
qtype,
|
||||
in_features,
|
||||
out_features,
|
||||
mp_group,
|
||||
cur_qtype,
|
||||
enable_xetla,
|
||||
optimize_lm_head,
|
||||
enable_scale_search)
|
||||
else:
|
||||
new_linear = LowBitLinear(
|
||||
in_features,
|
||||
out_features,
|
||||
|
|
@ -427,6 +504,19 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
|
|||
.to(device)
|
||||
elif qtype == ggml_tensor_qtype["fp16"]:
|
||||
module.to(torch.float16)
|
||||
if _USE_VLLM:
|
||||
new_linear = convert_vllm(
|
||||
module,
|
||||
qtype,
|
||||
in_features,
|
||||
out_features,
|
||||
mp_group,
|
||||
None,
|
||||
None,
|
||||
optimize_lm_head,
|
||||
None
|
||||
)
|
||||
else:
|
||||
new_linear = FP16Linear(
|
||||
in_features,
|
||||
out_features,
|
||||
|
|
@ -449,6 +539,19 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
|
|||
.to(device)
|
||||
elif qtype == ggml_tensor_qtype["bf16"]:
|
||||
module.to(torch.bfloat16)
|
||||
if _USE_VLLM:
|
||||
new_linear = convert_vllm(
|
||||
module,
|
||||
qtype,
|
||||
in_features,
|
||||
out_features,
|
||||
mp_group,
|
||||
None,
|
||||
None,
|
||||
optimize_lm_head,
|
||||
None
|
||||
)
|
||||
else:
|
||||
new_linear = BF16Linear(
|
||||
in_features,
|
||||
out_features,
|
||||
|
|
|
|||
|
|
@ -1009,3 +1009,38 @@ class BF16Linear(nn.Linear):
|
|||
result = result.reshape(*original_shape[:-1], result.shape[-1])
|
||||
|
||||
return result.to(x.dtype)
|
||||
|
||||
|
||||
class vLLMLowBitLinear(LowBitLinear):
|
||||
def __init__(self, input_features, output_features, qtype, bias=True,
|
||||
conver_to_half=True, mp_group=None, enable_xetla=False,
|
||||
optimize_lm_head=False, act_order=False,
|
||||
enable_scale_search=False):
|
||||
super().__init__(input_features, output_features, qtype, bias, conver_to_half, mp_group,
|
||||
enable_xetla, optimize_lm_head, act_order, enable_scale_search)
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
result = super().forward(x)
|
||||
return result, None
|
||||
|
||||
|
||||
class vLLMFP16Linear(FP16Linear):
|
||||
def __init__(self, input_features, output_features, bias=True, mp_group=None, weight_type=1,
|
||||
enable_xetla=False, optimize_lm_head=False):
|
||||
super().__init__(input_features, output_features, bias, mp_group, weight_type,
|
||||
enable_xetla, optimize_lm_head)
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
result = super().forward(x)
|
||||
return result, None
|
||||
|
||||
|
||||
class vLLMBF16Linear(BF16Linear):
|
||||
def __init__(self, input_features, output_features, bias=True, mp_group=None,
|
||||
compute_dtype=None, enable_xetla=False, optimize_lm_head=False):
|
||||
super().__init__(input_features, output_features, bias, mp_group, compute_dtype,
|
||||
enable_xetla, optimize_lm_head)
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
result = super().forward(x)
|
||||
return result, None
|
||||
|
|
|
|||
|
|
@ -225,8 +225,8 @@ def _ipex_llm_convert(load_in_low_bit):
|
|||
|
||||
def get_load_function(low_bit):
|
||||
def _ipex_llm_load_model(self) -> None:
|
||||
_model_mlp_convert()
|
||||
_model_attention_convert()
|
||||
# _model_mlp_convert()
|
||||
# _model_attention_convert()
|
||||
_model_sample_convert()
|
||||
|
||||
from vllm.utils import measure_device_memory
|
||||
|
|
|
|||
Loading…
Reference in a new issue