parent 750d4ad5dc
commit e70ae0638e

3 changed files with 164 additions and 26 deletions
@@ -180,6 +180,8 @@ def is_linear_module(module):
             out_features = module.output_size
             result = True
             mp_group = None
+            invalidInputError(module.skip_bias_add is not True, "Currently, ipex-vllm does not"
+                              " support linear layers with skip_bias_add argument")
             if isinstance(module, RowParallelLinear) and tp_size >= 2:
                 mp_group = get_tensor_model_parallel_group()
                 in_features = module.input_size_per_partition
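Note: the added invalidInputError check makes is_linear_module fail fast when a vLLM linear layer was built with skip_bias_add=True, since (per the message) ipex-vllm does not currently support linear layers that skip the bias add.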
@@ -218,6 +220,70 @@ def is_linear_module(module):
     return result, (in_features, out_features, mp_group)
 
 
+def convert_vllm(module, qtype, in_features, out_features, mp_group, cur_qtype,
+                 enable_xetla, optimize_lm_head, enable_scale_search):
+    from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
+    from ipex_llm.transformers.low_bit_linear import LowBitLinear, \
+        FP16Linear, BF16Linear, vLLMLowBitLinear, vLLMFP16Linear, vLLMBF16Linear
+    if isinstance(module, ParallelLMHead):
+        if qtype == ggml_tensor_qtype["fp16"]:
+            new_linear = FP16Linear(
+                in_features,
+                out_features,
+                module.bias is not None,
+                mp_group=mp_group,
+                optimize_lm_head=optimize_lm_head
+            )
+        elif qtype == ggml_tensor_qtype["bf16"]:
+            new_linear = BF16Linear(
+                in_features,
+                out_features,
+                module.bias is not None,
+                mp_group=mp_group,
+                optimize_lm_head=optimize_lm_head
+            )
+        else:
+            new_linear = LowBitLinear(
+                in_features,
+                out_features,
+                cur_qtype,
+                module.bias is not None,
+                mp_group=mp_group,
+                enable_xetla=enable_xetla,
+                optimize_lm_head=optimize_lm_head,
+                enable_scale_search=enable_scale_search,
+            )
+    else:
+        if qtype == ggml_tensor_qtype["fp16"]:
+            new_linear = vLLMFP16Linear(
+                in_features,
+                out_features,
+                module.bias is not None,
+                mp_group=mp_group,
+                optimize_lm_head=optimize_lm_head
+            )
+        elif qtype == ggml_tensor_qtype["bf16"]:
+            new_linear = vLLMBF16Linear(
+                in_features,
+                out_features,
+                module.bias is not None,
+                mp_group=mp_group,
+                optimize_lm_head=optimize_lm_head
+            )
+        else:
+            new_linear = vLLMLowBitLinear(
+                in_features,
+                out_features,
+                cur_qtype,
+                module.bias is not None,
+                mp_group=mp_group,
+                enable_xetla=enable_xetla,
+                optimize_lm_head=optimize_lm_head,
+                enable_scale_search=enable_scale_search,
+            )
+    return new_linear
+
+
 def convert_gptq(module, awq=False, llm_awq=False, act_order=False):
     from ipex_llm.transformers.low_bit_linear import get_block_size
     Q4_1 = get_block_size("asym_int4")
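Note: convert_vllm picks the replacement class from two inputs: whether the module is a ParallelLMHead (which keeps the plain ipex-llm linears) and the requested qtype. As a reading aid only, not part of the patch, the branches above can be summarized as a small Python table:

    # Reading aid (assumed summary of the branches in convert_vllm, not code from the patch):
    # maps (is_parallel_lm_head, qtype) to the replacement class that gets constructed.
    SELECTION = {
        (True, "fp16"): "FP16Linear",
        (True, "bf16"): "BF16Linear",
        (True, "low-bit"): "LowBitLinear",
        (False, "fp16"): "vLLMFP16Linear",
        (False, "bf16"): "vLLMBF16Linear",
        (False, "low-bit"): "vLLMLowBitLinear",
    }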
@@ -399,16 +465,27 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                     # check hidden size whether is a multiple of 256
                     cur_qtype = check_hidden_size(cur_qtype, in_features)
 
-                    new_linear = LowBitLinear(
-                        in_features,
-                        out_features,
-                        cur_qtype,
-                        module.bias is not None,
-                        mp_group=mp_group,
-                        enable_xetla=enable_xetla,
-                        optimize_lm_head=optimize_lm_head,
-                        enable_scale_search=enable_scale_search,
-                    )
+                    if _USE_VLLM:
+                        new_linear = convert_vllm(module,
+                                                  qtype,
+                                                  in_features,
+                                                  out_features,
+                                                  mp_group,
+                                                  cur_qtype,
+                                                  enable_xetla,
+                                                  optimize_lm_head,
+                                                  enable_scale_search)
+                    else:
+                        new_linear = LowBitLinear(
+                            in_features,
+                            out_features,
+                            cur_qtype,
+                            module.bias is not None,
+                            mp_group=mp_group,
+                            enable_xetla=enable_xetla,
+                            optimize_lm_head=optimize_lm_head,
+                            enable_scale_search=enable_scale_search,
+                        )
                     device = module.weight.data.device
                     # Copy the weights
                     paramsLowBit = FP4Params(data=module.weight.data,
@@ -427,13 +504,26 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                             .to(device)
                 elif qtype == ggml_tensor_qtype["fp16"]:
                     module.to(torch.float16)
-                    new_linear = FP16Linear(
-                        in_features,
-                        out_features,
-                        module.bias is not None,
-                        mp_group=mp_group,
-                        optimize_lm_head=optimize_lm_head
-                    )
+                    if _USE_VLLM:
+                        new_linear = convert_vllm(
+                            module,
+                            qtype,
+                            in_features,
+                            out_features,
+                            mp_group,
+                            None,
+                            None,
+                            optimize_lm_head,
+                            None
+                        )
+                    else:
+                        new_linear = FP16Linear(
+                            in_features,
+                            out_features,
+                            module.bias is not None,
+                            mp_group=mp_group,
+                            optimize_lm_head=optimize_lm_head
+                        )
                     device = module.weight.data.device
                     from ipex_llm.transformers.utils import get_ipex_version
                     if get_ipex_version() < "2.1.10+xpu":
@@ -449,13 +539,26 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                             .to(device)
                 elif qtype == ggml_tensor_qtype["bf16"]:
                     module.to(torch.bfloat16)
-                    new_linear = BF16Linear(
-                        in_features,
-                        out_features,
-                        module.bias is not None,
-                        mp_group=mp_group,
-                        optimize_lm_head=optimize_lm_head
-                    )
+                    if _USE_VLLM:
+                        new_linear = convert_vllm(
+                            module,
+                            qtype,
+                            in_features,
+                            out_features,
+                            mp_group,
+                            None,
+                            None,
+                            optimize_lm_head,
+                            None
+                        )
+                    else:
+                        new_linear = BF16Linear(
+                            in_features,
+                            out_features,
+                            module.bias is not None,
+                            mp_group=mp_group,
+                            optimize_lm_head=optimize_lm_head
+                        )
                     device = module.weight.data.device
                     # convert here
                     new_linear._parameters['weight'] = nn.Parameter(module.weight)
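Note: in both the fp16 and bf16 branches, convert_vllm is called with None in the cur_qtype, enable_xetla, and enable_scale_search positions; those arguments are only consumed on the low-bit (LowBitLinear / vLLMLowBitLinear) path inside convert_vllm, which is never reached when qtype is fp16 or bf16, so the placeholders are never read.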
@@ -1009,3 +1009,38 @@ class BF16Linear(nn.Linear):
             result = result.reshape(*original_shape[:-1], result.shape[-1])
 
         return result.to(x.dtype)
+
+
+class vLLMLowBitLinear(LowBitLinear):
+    def __init__(self, input_features, output_features, qtype, bias=True,
+                 conver_to_half=True, mp_group=None, enable_xetla=False,
+                 optimize_lm_head=False, act_order=False,
+                 enable_scale_search=False):
+        super().__init__(input_features, output_features, qtype, bias, conver_to_half, mp_group,
+                         enable_xetla, optimize_lm_head, act_order, enable_scale_search)
+
+    def forward(self, x: torch.Tensor):
+        result = super().forward(x)
+        return result, None
+
+
+class vLLMFP16Linear(FP16Linear):
+    def __init__(self, input_features, output_features, bias=True, mp_group=None, weight_type=1,
+                 enable_xetla=False, optimize_lm_head=False):
+        super().__init__(input_features, output_features, bias, mp_group, weight_type,
+                         enable_xetla, optimize_lm_head)
+
+    def forward(self, x: torch.Tensor):
+        result = super().forward(x)
+        return result, None
+
+
+class vLLMBF16Linear(BF16Linear):
+    def __init__(self, input_features, output_features, bias=True, mp_group=None,
+                 compute_dtype=None, enable_xetla=False, optimize_lm_head=False):
+        super().__init__(input_features, output_features, bias, mp_group, compute_dtype,
+                         enable_xetla, optimize_lm_head)
+
+    def forward(self, x: torch.Tensor):
+        result = super().forward(x)
+        return result, None
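Note: the three vLLM* wrappers only change the return convention. vLLM's parallel linear layers hand back an (output, bias) tuple from forward, so each wrapper delegates to its ipex-llm parent class and returns (result, None), with the bias (if any) already applied by the parent's forward. A minimal, self-contained sketch of the same adapter pattern, using a plain torch.nn.Linear as a stand-in for the ipex-llm classes (names here are illustrative, not from the patch):

    import torch
    import torch.nn as nn

    class TupleReturnLinear(nn.Linear):
        # Stand-in for the vLLM* wrappers: same computation, tuple return.
        def forward(self, x: torch.Tensor):
            result = super().forward(x)
            # Callers expecting vLLM's (output, output_bias) convention can
            # unpack this directly; the bias is already folded into result.
            return result, None

    layer = TupleReturnLinear(8, 4)
    out, bias = layer(torch.randn(2, 8))
    print(out.shape, bias)  # torch.Size([2, 4]) None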
@@ -225,8 +225,8 @@ def _ipex_llm_convert(load_in_low_bit):
 
 def get_load_function(low_bit):
     def _ipex_llm_load_model(self) -> None:
-        _model_mlp_convert()
-        _model_attention_convert()
+        # _model_mlp_convert()
+        # _model_attention_convert()
         _model_sample_convert()
 
         from vllm.utils import measure_device_memory
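Note: with _model_mlp_convert() and _model_attention_convert() commented out, _ipex_llm_load_model now applies only _model_sample_convert() before loading the model.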