diff --git a/python/llm/src/ipex_llm/transformers/convert.py b/python/llm/src/ipex_llm/transformers/convert.py
index ac6081e1..1adc00b9 100644
--- a/python/llm/src/ipex_llm/transformers/convert.py
+++ b/python/llm/src/ipex_llm/transformers/convert.py
@@ -180,6 +180,8 @@ def is_linear_module(module):
             out_features = module.output_size
             result = True
             mp_group = None
+            invalidInputError(module.skip_bias_add is not True, "Currently, ipex-vllm does not"
+                              " support linear layers with skip_bias_add argument")
             if isinstance(module, RowParallelLinear) and tp_size >= 2:
                 mp_group = get_tensor_model_parallel_group()
                 in_features = module.input_size_per_partition
@@ -218,6 +220,70 @@ def is_linear_module(module):
     return result, (in_features, out_features, mp_group)
 
 
+def convert_vllm(module, qtype, in_features, out_features, mp_group, cur_qtype,
+                 enable_xetla, optimize_lm_head, enable_scale_search):
+    from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
+    from ipex_llm.transformers.low_bit_linear import LowBitLinear, \
+        FP16Linear, BF16Linear, vLLMLowBitLinear, vLLMFP16Linear, vLLMBF16Linear
+    if isinstance(module, ParallelLMHead):
+        if qtype == ggml_tensor_qtype["fp16"]:
+            new_linear = FP16Linear(
+                in_features,
+                out_features,
+                module.bias is not None,
+                mp_group=mp_group,
+                optimize_lm_head=optimize_lm_head
+            )
+        elif qtype == ggml_tensor_qtype["bf16"]:
+            new_linear = BF16Linear(
+                in_features,
+                out_features,
+                module.bias is not None,
+                mp_group=mp_group,
+                optimize_lm_head=optimize_lm_head
+            )
+        else:
+            new_linear = LowBitLinear(
+                in_features,
+                out_features,
+                cur_qtype,
+                module.bias is not None,
+                mp_group=mp_group,
+                enable_xetla=enable_xetla,
+                optimize_lm_head=optimize_lm_head,
+                enable_scale_search=enable_scale_search,
+            )
+    else:
+        if qtype == ggml_tensor_qtype["fp16"]:
+            new_linear = vLLMFP16Linear(
+                in_features,
+                out_features,
+                module.bias is not None,
+                mp_group=mp_group,
+                optimize_lm_head=optimize_lm_head
+            )
+        elif qtype == ggml_tensor_qtype["bf16"]:
+            new_linear = vLLMBF16Linear(
+                in_features,
+                out_features,
+                module.bias is not None,
+                mp_group=mp_group,
+                optimize_lm_head=optimize_lm_head
+            )
+        else:
+            new_linear = vLLMLowBitLinear(
+                in_features,
+                out_features,
+                cur_qtype,
+                module.bias is not None,
+                mp_group=mp_group,
+                enable_xetla=enable_xetla,
+                optimize_lm_head=optimize_lm_head,
+                enable_scale_search=enable_scale_search,
+            )
+    return new_linear
+
+
 def convert_gptq(module, awq=False, llm_awq=False, act_order=False):
     from ipex_llm.transformers.low_bit_linear import get_block_size
     Q4_1 = get_block_size("asym_int4")
@@ -399,16 +465,27 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
 
                     # check hidden size whether is a multiple of 256
                     cur_qtype = check_hidden_size(cur_qtype, in_features)
-                    new_linear = LowBitLinear(
-                        in_features,
-                        out_features,
-                        cur_qtype,
-                        module.bias is not None,
-                        mp_group=mp_group,
-                        enable_xetla=enable_xetla,
-                        optimize_lm_head=optimize_lm_head,
-                        enable_scale_search=enable_scale_search,
-                    )
+                    if _USE_VLLM:
+                        new_linear = convert_vllm(module,
+                                                  qtype,
+                                                  in_features,
+                                                  out_features,
+                                                  mp_group,
+                                                  cur_qtype,
+                                                  enable_xetla,
+                                                  optimize_lm_head,
+                                                  enable_scale_search)
+                    else:
+                        new_linear = LowBitLinear(
+                            in_features,
+                            out_features,
+                            cur_qtype,
+                            module.bias is not None,
+                            mp_group=mp_group,
+                            enable_xetla=enable_xetla,
+                            optimize_lm_head=optimize_lm_head,
+                            enable_scale_search=enable_scale_search,
+                        )
                     device = module.weight.data.device
                     # Copy the weights
                     paramsLowBit = FP4Params(data=module.weight.data,
@@ -427,13 +504,26 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                         .to(device)
                 elif qtype == ggml_tensor_qtype["fp16"]:
                     module.to(torch.float16)
-                    new_linear = FP16Linear(
-                        in_features,
-                        out_features,
-                        module.bias is not None,
-                        mp_group=mp_group,
-                        optimize_lm_head=optimize_lm_head
-                    )
+                    if _USE_VLLM:
+                        new_linear = convert_vllm(
+                            module,
+                            qtype,
+                            in_features,
+                            out_features,
+                            mp_group,
+                            None,
+                            None,
+                            optimize_lm_head,
+                            None
+                        )
+                    else:
+                        new_linear = FP16Linear(
+                            in_features,
+                            out_features,
+                            module.bias is not None,
+                            mp_group=mp_group,
+                            optimize_lm_head=optimize_lm_head
+                        )
                     device = module.weight.data.device
                     from ipex_llm.transformers.utils import get_ipex_version
                     if get_ipex_version() < "2.1.10+xpu":
@@ -449,13 +539,26 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                         .to(device)
                 elif qtype == ggml_tensor_qtype["bf16"]:
                     module.to(torch.bfloat16)
-                    new_linear = BF16Linear(
-                        in_features,
-                        out_features,
-                        module.bias is not None,
-                        mp_group=mp_group,
-                        optimize_lm_head=optimize_lm_head
-                    )
+                    if _USE_VLLM:
+                        new_linear = convert_vllm(
+                            module,
+                            qtype,
+                            in_features,
+                            out_features,
+                            mp_group,
+                            None,
+                            None,
+                            optimize_lm_head,
+                            None
+                        )
+                    else:
+                        new_linear = BF16Linear(
+                            in_features,
+                            out_features,
+                            module.bias is not None,
+                            mp_group=mp_group,
+                            optimize_lm_head=optimize_lm_head
+                        )
                     device = module.weight.data.device
                     # convert here
                     new_linear._parameters['weight'] = nn.Parameter(module.weight)
diff --git a/python/llm/src/ipex_llm/transformers/low_bit_linear.py b/python/llm/src/ipex_llm/transformers/low_bit_linear.py
index aacd288c..fddcb7c9 100644
--- a/python/llm/src/ipex_llm/transformers/low_bit_linear.py
+++ b/python/llm/src/ipex_llm/transformers/low_bit_linear.py
@@ -1009,3 +1009,38 @@ class BF16Linear(nn.Linear):
         result = result.reshape(*original_shape[:-1], result.shape[-1])
 
         return result.to(x.dtype)
+
+
+class vLLMLowBitLinear(LowBitLinear):
+    def __init__(self, input_features, output_features, qtype, bias=True,
+                 conver_to_half=True, mp_group=None, enable_xetla=False,
+                 optimize_lm_head=False, act_order=False,
+                 enable_scale_search=False):
+        super().__init__(input_features, output_features, qtype, bias, conver_to_half, mp_group,
+                         enable_xetla, optimize_lm_head, act_order, enable_scale_search)
+
+    def forward(self, x: torch.Tensor):
+        result = super().forward(x)
+        return result, None
+
+
+class vLLMFP16Linear(FP16Linear):
+    def __init__(self, input_features, output_features, bias=True, mp_group=None, weight_type=1,
+                 enable_xetla=False, optimize_lm_head=False):
+        super().__init__(input_features, output_features, bias, mp_group, weight_type,
+                         enable_xetla, optimize_lm_head)
+
+    def forward(self, x: torch.Tensor):
+        result = super().forward(x)
+        return result, None
+
+
+class vLLMBF16Linear(BF16Linear):
+    def __init__(self, input_features, output_features, bias=True, mp_group=None,
+                 compute_dtype=None, enable_xetla=False, optimize_lm_head=False):
+        super().__init__(input_features, output_features, bias, mp_group, compute_dtype,
+                         enable_xetla, optimize_lm_head)
+
+    def forward(self, x: torch.Tensor):
+        result = super().forward(x)
+        return result, None
diff --git a/python/llm/src/ipex_llm/vllm/xpu/model_convert.py b/python/llm/src/ipex_llm/vllm/xpu/model_convert.py
index 85719ccd..79b56c3b 100644
--- a/python/llm/src/ipex_llm/vllm/xpu/model_convert.py
+++ b/python/llm/src/ipex_llm/vllm/xpu/model_convert.py
@@ -225,8 +225,8 @@ def _ipex_llm_convert(load_in_low_bit):
 
 def get_load_function(low_bit):
     def _ipex_llm_load_model(self) -> None:
-        _model_mlp_convert()
-        _model_attention_convert()
+        # _model_mlp_convert()
+        # _model_attention_convert()
         _model_sample_convert()
 
         from vllm.utils import measure_device_memory
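
Note (not part of the patch): the new vLLM*Linear wrappers in low_bit_linear.py return a (output, output_bias) pair because that matches the convention of vLLM's own linear layers, whose callers unpack two values; since skip_bias_add is rejected in is_linear_module, the bias is already folded into the output and the second element is always None. A minimal, self-contained sketch of that wrapper pattern, using plain torch.nn.Linear instead of the ipex-llm classes (TupleReturnLinear is an illustrative name, not part of the code above):

import torch
import torch.nn as nn


class TupleReturnLinear(nn.Linear):
    # Toy stand-in for vLLMLowBitLinear: same math as nn.Linear, but forward
    # returns the (output, output_bias) pair that vLLM-style callers unpack.
    def forward(self, x: torch.Tensor):
        out = super().forward(x)  # bias (if present) is added here
        return out, None          # no deferred bias, mirroring skip_bias_add=False


layer = TupleReturnLinear(16, 32)
y, y_bias = layer(torch.randn(4, 16))
assert y_bias is None and y.shape == (4, 32)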