remove unused code again (#12624)
This commit is contained in:
		
							parent
							
								
									46eeab4479
								
							
						
					
					
						commit
						c72a5db757
					
				
					 10 changed files with 19 additions and 171 deletions
				
			
		| 
						 | 
				
			
			@ -92,8 +92,7 @@ def train(
 | 
			
		|||
        load_in_low_bit="bf16",
 | 
			
		||||
        optimize_model=True,
 | 
			
		||||
        torch_dtype=torch.bfloat16,
 | 
			
		||||
        trust_remote_code=True,
 | 
			
		||||
        enable_xetla=False
 | 
			
		||||
        trust_remote_code=True
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    model = model.to("xpu")
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -257,8 +257,7 @@ def optimize_model(model, low_bit='sym_int4', optimize_llm=True, modules_to_not_
 | 
			
		|||
                                 optimize_model=optimize_llm,
 | 
			
		||||
                                 modules_to_not_convert=modules_to_not_convert,
 | 
			
		||||
                                 cpu_embedding=cpu_embedding,
 | 
			
		||||
                                 lightweight_bmm=lightweight_bmm,
 | 
			
		||||
                                 enable_xetla=kwargs.pop("enable_xetla", False))
 | 
			
		||||
                                 lightweight_bmm=lightweight_bmm)
 | 
			
		||||
    # add save_low_bit to pretrained model dynamically
 | 
			
		||||
    import types
 | 
			
		||||
    model._bigdl_config = dict()
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -232,7 +232,7 @@ def is_linear_module(module):
 | 
			
		|||
 | 
			
		||||
 | 
			
		||||
def convert_vllm(module, qtype, in_features, out_features, mp_group, cur_qtype,
 | 
			
		||||
                 enable_xetla, optimize_lm_head, enable_scale_search):
 | 
			
		||||
                 optimize_lm_head, enable_scale_search):
 | 
			
		||||
    from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
 | 
			
		||||
    from ipex_llm.transformers.low_bit_linear import LowBitLinear, \
 | 
			
		||||
        FP16Linear, BF16Linear, vLLMLowBitLinear, vLLMFP16Linear, vLLMBF16Linear
 | 
			
		||||
| 
						 | 
				
			
			@ -261,7 +261,6 @@ def convert_vllm(module, qtype, in_features, out_features, mp_group, cur_qtype,
 | 
			
		|||
                cur_qtype,
 | 
			
		||||
                module.bias is not None,
 | 
			
		||||
                mp_group=mp_group,
 | 
			
		||||
                enable_xetla=enable_xetla,
 | 
			
		||||
                optimize_lm_head=optimize_lm_head,
 | 
			
		||||
                enable_scale_search=enable_scale_search,
 | 
			
		||||
            )
 | 
			
		||||
| 
						 | 
				
			
			@ -289,7 +288,6 @@ def convert_vllm(module, qtype, in_features, out_features, mp_group, cur_qtype,
 | 
			
		|||
                cur_qtype,
 | 
			
		||||
                module.bias is not None,
 | 
			
		||||
                mp_group=mp_group,
 | 
			
		||||
                enable_xetla=enable_xetla,
 | 
			
		||||
                optimize_lm_head=optimize_lm_head,
 | 
			
		||||
                enable_scale_search=enable_scale_search,
 | 
			
		||||
            )
 | 
			
		||||
| 
						 | 
				
			
			@ -473,7 +471,6 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
 | 
			
		|||
                                 prefix_name='',
 | 
			
		||||
                                 imatrix_data=None, embedding_qtype=None,
 | 
			
		||||
                                 model_config=None, torch_dtype=torch.float32,
 | 
			
		||||
                                 enable_xetla=False,
 | 
			
		||||
                                 mixed_precision=False,
 | 
			
		||||
                                 act_order=False,
 | 
			
		||||
                                 enable_scale_search=False,
 | 
			
		||||
| 
						 | 
				
			
			@ -523,7 +520,6 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
 | 
			
		|||
                        qtype=qtype,
 | 
			
		||||
                        bias=has_bias,
 | 
			
		||||
                        mp_group=mp_group,
 | 
			
		||||
                        enable_xetla=enable_xetla,
 | 
			
		||||
                        optimize_lm_head=optimize_lm_head,
 | 
			
		||||
                        act_order=act_order,
 | 
			
		||||
                        enable_scale_search=enable_scale_search,
 | 
			
		||||
| 
						 | 
				
			
			@ -544,7 +540,6 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
 | 
			
		|||
                                             _shape=(out_features, in_features),
 | 
			
		||||
                                             convert_shape_only=convert_shape_only,
 | 
			
		||||
                                             qtype=qtype,
 | 
			
		||||
                                             enable_xetla=enable_xetla,
 | 
			
		||||
                                             enable_scale_search=enable_scale_search).to(device)
 | 
			
		||||
                    new_linear._parameters['weight'] = paramsLowBit
 | 
			
		||||
                    if has_bias:
 | 
			
		||||
| 
						 | 
				
			
			@ -562,7 +557,6 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
 | 
			
		|||
                            qtype=qtype,
 | 
			
		||||
                            bias=has_bias,
 | 
			
		||||
                            mp_group=mp_group,
 | 
			
		||||
                            enable_xetla=enable_xetla,
 | 
			
		||||
                            optimize_lm_head=False,
 | 
			
		||||
                            act_order=act_order,
 | 
			
		||||
                            enable_scale_search=enable_scale_search,
 | 
			
		||||
| 
						 | 
				
			
			@ -581,7 +575,6 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
 | 
			
		|||
                                                 qtype=cur_qtype,
 | 
			
		||||
                                                 imatrix=cur_imatrix,
 | 
			
		||||
                                                 in_features=in_features,
 | 
			
		||||
                                                 enable_xetla=enable_xetla,
 | 
			
		||||
                                                 enable_scale_search=enable_scale_search).to(device)
 | 
			
		||||
                    else:
 | 
			
		||||
                        new_linear = vLLMLowBitLinear(
 | 
			
		||||
| 
						 | 
				
			
			@ -590,7 +583,6 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
 | 
			
		|||
                            qtype=qtype,
 | 
			
		||||
                            bias=has_bias,
 | 
			
		||||
                            mp_group=mp_group,
 | 
			
		||||
                            enable_xetla=enable_xetla,
 | 
			
		||||
                            optimize_lm_head=False,
 | 
			
		||||
                            act_order=act_order,
 | 
			
		||||
                            enable_scale_search=enable_scale_search,
 | 
			
		||||
| 
						 | 
				
			
			@ -609,7 +601,6 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
 | 
			
		|||
                                                 _shape=(out_features, in_features),
 | 
			
		||||
                                                 convert_shape_only=convert_shape_only,
 | 
			
		||||
                                                 qtype=qtype,
 | 
			
		||||
                                                 enable_xetla=enable_xetla,
 | 
			
		||||
                                                 enable_scale_search=enable_scale_search).to(device)
 | 
			
		||||
                    new_linear._parameters['weight'] = paramsLowBit
 | 
			
		||||
                    if has_bias:
 | 
			
		||||
| 
						 | 
				
			
			@ -639,7 +630,6 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
 | 
			
		|||
                                                  out_features,
 | 
			
		||||
                                                  mp_group,
 | 
			
		||||
                                                  cur_qtype,
 | 
			
		||||
                                                  enable_xetla,
 | 
			
		||||
                                                  optimize_lm_head,
 | 
			
		||||
                                                  enable_scale_search)
 | 
			
		||||
                    else:
 | 
			
		||||
| 
						 | 
				
			
			@ -649,7 +639,6 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
 | 
			
		|||
                            cur_qtype,
 | 
			
		||||
                            module.bias is not None,
 | 
			
		||||
                            mp_group=mp_group,
 | 
			
		||||
                            enable_xetla=enable_xetla,
 | 
			
		||||
                            optimize_lm_head=optimize_lm_head,
 | 
			
		||||
                            enable_scale_search=enable_scale_search,
 | 
			
		||||
                        )
 | 
			
		||||
| 
						 | 
				
			
			@ -663,7 +652,6 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
 | 
			
		|||
                                             qtype=cur_qtype,
 | 
			
		||||
                                             imatrix=cur_imatrix,
 | 
			
		||||
                                             in_features=in_features,
 | 
			
		||||
                                             enable_xetla=enable_xetla,
 | 
			
		||||
                                             enable_scale_search=enable_scale_search).to(device)
 | 
			
		||||
                    new_linear._parameters['weight'] = paramsLowBit
 | 
			
		||||
                    if module.bias is not None:
 | 
			
		||||
| 
						 | 
				
			
			@ -762,7 +750,6 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
 | 
			
		|||
                embedding_qtype=embedding_qtype,
 | 
			
		||||
                model_config=model_config,
 | 
			
		||||
                torch_dtype=torch_dtype,
 | 
			
		||||
                enable_xetla=enable_xetla,
 | 
			
		||||
                mixed_precision=mixed_precision,
 | 
			
		||||
                act_order=act_order,
 | 
			
		||||
                enable_scale_search=enable_scale_search,
 | 
			
		||||
| 
						 | 
				
			
			@ -1094,7 +1081,6 @@ def ggml_convert_low_bit(model, qtype, optimize_model=True,
 | 
			
		|||
                         lightweight_bmm=False, torch_dtype="auto",
 | 
			
		||||
                         imatrix_data=None,
 | 
			
		||||
                         embedding_qtype=None,
 | 
			
		||||
                         enable_xetla=False,
 | 
			
		||||
                         mixed_precision=False):
 | 
			
		||||
    if qtype in ggml_tensor_qtype.values():
 | 
			
		||||
        index = list(ggml_tensor_qtype.values()).index(qtype)
 | 
			
		||||
| 
						 | 
				
			
			@ -1138,7 +1124,6 @@ def ggml_convert_low_bit(model, qtype, optimize_model=True,
 | 
			
		|||
            embedding_qtype=embedding_qtype,
 | 
			
		||||
            model_config=model_config,
 | 
			
		||||
            torch_dtype=torch_dtype,
 | 
			
		||||
            enable_xetla=enable_xetla,
 | 
			
		||||
            mixed_precision=mixed_precision,
 | 
			
		||||
            act_order=act_order,
 | 
			
		||||
            enable_scale_search=enable_scale_search,
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -92,113 +92,6 @@ RTN_DTYPE = {
 | 
			
		|||
}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# For sym_int4
 | 
			
		||||
# The ggml_weight is col major and packs two rows at a stride of Q4_0//2.
 | 
			
		||||
#
 | 
			
		||||
# The returning weight is row major and packs two rows at a stride of 16//2.
 | 
			
		||||
# 16 is the tile_size_y used in mm_xetla, so that we can do something like
 | 
			
		||||
# new_weight_tile = concat(weight_tile & 0x0F, weight_tile >> 4).
 | 
			
		||||
#
 | 
			
		||||
# A more complex packing strategy is to permute the weight so that the
 | 
			
		||||
# new_weight_tile is directly VNNI packed, but I did not find significant
 | 
			
		||||
# performance improvement.
 | 
			
		||||
#
 | 
			
		||||
# Note this format cannot be used directly in IPEX-LLM's mm_int4, which expects
 | 
			
		||||
# row major but packing two consecutive columns.
 | 
			
		||||
#
 | 
			
		||||
# For fp8, just remove the scales (which are all ones) and transpose
 | 
			
		||||
def ggml_xpu_to_ipex_llm_xetla(ggml_weight, weight_shape, qtype):
 | 
			
		||||
    if qtype == ggml_tensor_qtype["sym_int4"]:
 | 
			
		||||
        from ipex_llm.transformers.low_bit_linear import get_block_size
 | 
			
		||||
        Q4_0 = get_block_size("sym_int4")
 | 
			
		||||
 | 
			
		||||
        n, k = weight_shape
 | 
			
		||||
        ggml_weight_only = ggml_weight[:n*k//2]
 | 
			
		||||
        ggml_scales = ggml_weight[n*k//2:]
 | 
			
		||||
 | 
			
		||||
        qweight = ggml_weight_only.clone()
 | 
			
		||||
        scales = ggml_scales.view(torch.float16).clone()
 | 
			
		||||
 | 
			
		||||
        qweight_0 = qweight & 0x0F
 | 
			
		||||
        qweight_1 = qweight >> 4
 | 
			
		||||
 | 
			
		||||
        qweight_0 = qweight_0.reshape(n, -1, Q4_0//2)
 | 
			
		||||
        qweight_1 = qweight_1.reshape(n, -1, Q4_0//2)
 | 
			
		||||
        qweight = torch.cat([qweight_0, qweight_1], dim=-1)
 | 
			
		||||
        qweight = qweight.reshape(n, k//16, 2, 8)
 | 
			
		||||
        qweight = qweight.bitwise_left_shift(
 | 
			
		||||
            torch.tensor([0, 4], dtype=torch.uint8, device=ggml_weight.device).reshape(1, 1, 2, 1))
 | 
			
		||||
 | 
			
		||||
        qweight = torch.bitwise_or(qweight[:, :, 0, :], qweight[:, :, 1, :])
 | 
			
		||||
        qweight = qweight.reshape(n, k//2)
 | 
			
		||||
        qweight = qweight.transpose(0, 1).contiguous()
 | 
			
		||||
 | 
			
		||||
        scales = scales.reshape(n, k//Q4_0).transpose(0, 1).contiguous()
 | 
			
		||||
 | 
			
		||||
        # 119 is the value of 0x77
 | 
			
		||||
        zeros = torch.ones([k//Q4_0, n//2], dtype=torch.uint8, device=ggml_weight.device) * (119)
 | 
			
		||||
 | 
			
		||||
        qweight_bytes = qweight.view(torch.uint8).view(-1)
 | 
			
		||||
        scales_bytes = scales.view(torch.uint8).view(-1)
 | 
			
		||||
        zeros_bytes = zeros.view(torch.uint8).view(-1)
 | 
			
		||||
 | 
			
		||||
        weight = torch.concat([qweight_bytes, zeros_bytes, scales_bytes], dim=0)
 | 
			
		||||
    elif qtype == ggml_tensor_qtype["fp8_e5m2"]:
 | 
			
		||||
        n, k = weight_shape
 | 
			
		||||
        weight = ggml_weight[:n*k].view(n, k).transpose(0, 1).contiguous()
 | 
			
		||||
    else:
 | 
			
		||||
        invalidInputError(False, f"Unsupported qtype {qtype}")
 | 
			
		||||
    return weight
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def ipex_llm_xetla_to_ggml_xpu(xetla_weight, weight_shape, qtype):
 | 
			
		||||
    from ipex_llm.transformers.low_bit_linear import get_block_size
 | 
			
		||||
    if qtype == ggml_tensor_qtype["sym_int4"]:
 | 
			
		||||
        Q4_0 = get_block_size("sym_int4")
 | 
			
		||||
        n, k = weight_shape
 | 
			
		||||
        weight_size = n*k//2
 | 
			
		||||
        zeros_size = n*k//Q4_0//2
 | 
			
		||||
        scales_size = n*k//Q4_0 * 2
 | 
			
		||||
        xetla_weight_only = xetla_weight[:weight_size]
 | 
			
		||||
        scales_start = weight_size + zeros_size
 | 
			
		||||
        xetla_scales = xetla_weight[scales_start:scales_start+scales_size]
 | 
			
		||||
 | 
			
		||||
        qweight = xetla_weight_only.clone()
 | 
			
		||||
        scales = xetla_scales.view(torch.float16).clone()
 | 
			
		||||
 | 
			
		||||
        qweight_0 = qweight & 0x0F
 | 
			
		||||
        qweight_1 = qweight >> 4
 | 
			
		||||
        qweight_0 = qweight_0.reshape(-1, 8, n)
 | 
			
		||||
        qweight_1 = qweight_1.reshape(-1, 8, n)
 | 
			
		||||
        qweight = torch.cat([qweight_0, qweight_1], dim=1)
 | 
			
		||||
 | 
			
		||||
        qweight = qweight.reshape(k, n).transpose(0, 1).contiguous().reshape(n, k//Q4_0,
 | 
			
		||||
                                                                             2, Q4_0//2)
 | 
			
		||||
        qweight = qweight.bitwise_left_shift(
 | 
			
		||||
            torch.tensor([0, 4], dtype=torch.uint8,
 | 
			
		||||
                         device=xetla_weight_only.device).reshape(1, 1, 2, 1))
 | 
			
		||||
 | 
			
		||||
        qweight = torch.bitwise_or(qweight[:, :, 0, :], qweight[:, :, 1, :])
 | 
			
		||||
        qweight = qweight.reshape(n, k//2)
 | 
			
		||||
 | 
			
		||||
        scales = scales.reshape(k//Q4_0, n).transpose(0, 1).contiguous()
 | 
			
		||||
 | 
			
		||||
        qweight_bytes = qweight.view(torch.uint8).view(-1)
 | 
			
		||||
        scales_bytes = scales.view(torch.uint8).view(-1)
 | 
			
		||||
        weight = torch.concat([qweight_bytes, scales_bytes], dim=0)
 | 
			
		||||
    elif qtype == ggml_tensor_qtype["fp8_e5m2"]:
 | 
			
		||||
        Q8_0 = get_block_size("fp8_e5m2")
 | 
			
		||||
        n, k = weight_shape
 | 
			
		||||
        qweight = xetla_weight[:n*k].transpose(0, 1).contiguous()
 | 
			
		||||
        scales = torch.ones([n*k//Q8_0], dtype=torch.float, device=xetla_weight.device)
 | 
			
		||||
        qweight_bytes = qweight.view(torch.uint8).view(-1)
 | 
			
		||||
        scales_bytes = scales.view(torch.uint8).view(-1)
 | 
			
		||||
        weight = torch.concat([qweight_bytes, scales_bytes], dim=0)
 | 
			
		||||
    else:
 | 
			
		||||
        invalidInputError(False, f"Unsupported qtype {qtype}")
 | 
			
		||||
    return weight
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_block_size(qtype: str):
 | 
			
		||||
    return ggml.ggml_qk_size(ggml_tensor_qtype[qtype])
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -422,7 +315,6 @@ class FP4Params(torch.nn.Parameter):
 | 
			
		|||
                qtype=None,
 | 
			
		||||
                imatrix=None,
 | 
			
		||||
                in_features=None,
 | 
			
		||||
                enable_xetla=False,
 | 
			
		||||
                enable_scale_search=False):
 | 
			
		||||
        if data is None:
 | 
			
		||||
            data = torch.empty(0)
 | 
			
		||||
| 
						 | 
				
			
			@ -435,7 +327,6 @@ class FP4Params(torch.nn.Parameter):
 | 
			
		|||
        self.convert_shape_only = convert_shape_only
 | 
			
		||||
        self.imatrix = imatrix
 | 
			
		||||
        self.in_features = in_features
 | 
			
		||||
        self.enable_xetla = enable_xetla
 | 
			
		||||
        self.enable_scale_search = enable_scale_search
 | 
			
		||||
        return self
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -529,8 +420,6 @@ class FP4Params(torch.nn.Parameter):
 | 
			
		|||
            self.data = ggml_q_format_convet_cpu2xpu(self.data,
 | 
			
		||||
                                                     reduce(mul, self._shape, 1),
 | 
			
		||||
                                                     self.qtype)
 | 
			
		||||
            if self.enable_xetla:
 | 
			
		||||
                self.data = ggml_xpu_to_ipex_llm_xetla(self.data, self._shape, self.qtype)
 | 
			
		||||
            new_param = FP4Params(super().to(device=device,
 | 
			
		||||
                                             dtype=dtype,
 | 
			
		||||
                                             non_blocking=non_blocking),
 | 
			
		||||
| 
						 | 
				
			
			@ -538,12 +427,7 @@ class FP4Params(torch.nn.Parameter):
 | 
			
		|||
                                  quantized=self.quantized,
 | 
			
		||||
                                  _shape=self._shape,
 | 
			
		||||
                                  qtype=self.qtype,
 | 
			
		||||
                                  enable_xetla=self.enable_xetla,
 | 
			
		||||
                                  enable_scale_search=self.enable_scale_search)
 | 
			
		||||
            if self.enable_xetla:
 | 
			
		||||
                device_type = get_xpu_device_type(new_param.data)
 | 
			
		||||
                invalidInputError(device_type == "pvc",
 | 
			
		||||
                                  f"xetla is only supported on PVC, but got {device_type}")
 | 
			
		||||
            return new_param
 | 
			
		||||
        elif (device is not None and device.type == "cpu" and self.data.device.type == "xpu"):
 | 
			
		||||
            new_param = FP4Params(super().to(device=device,
 | 
			
		||||
| 
						 | 
				
			
			@ -553,13 +437,7 @@ class FP4Params(torch.nn.Parameter):
 | 
			
		|||
                                  quantized=self.quantized,
 | 
			
		||||
                                  _shape=self._shape,
 | 
			
		||||
                                  qtype=self.qtype,
 | 
			
		||||
                                  enable_xetla=self.enable_xetla,
 | 
			
		||||
                                  enable_scale_search=self.enable_scale_search)
 | 
			
		||||
            if self.enable_xetla:
 | 
			
		||||
                ggml_xpu = ipex_llm_xetla_to_ggml_xpu(new_param.data,
 | 
			
		||||
                                                      new_param._shape,
 | 
			
		||||
                                                      new_param.qtype)
 | 
			
		||||
            else:
 | 
			
		||||
            ggml_xpu = new_param.data
 | 
			
		||||
            new_param.data = ggml_q_format_convet_xpu2cpu(ggml_xpu,
 | 
			
		||||
                                                          reduce(mul, new_param._shape, 1),
 | 
			
		||||
| 
						 | 
				
			
			@ -573,7 +451,6 @@ class FP4Params(torch.nn.Parameter):
 | 
			
		|||
                                  quantized=self.quantized,
 | 
			
		||||
                                  _shape=self._shape,
 | 
			
		||||
                                  qtype=self.qtype,
 | 
			
		||||
                                  enable_xetla=self.enable_xetla,
 | 
			
		||||
                                  enable_scale_search=self.enable_scale_search)
 | 
			
		||||
            return new_param
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -691,14 +568,13 @@ class MatMulLowBitCPU(torch.autograd.Function):
 | 
			
		|||
 | 
			
		||||
class LowBitLinear(nn.Linear):
 | 
			
		||||
    def __init__(self, input_features, output_features, qtype, bias=True,
 | 
			
		||||
                 conver_to_half=True, mp_group=None, enable_xetla=False,
 | 
			
		||||
                 conver_to_half=True, mp_group=None,
 | 
			
		||||
                 optimize_lm_head=False, act_order=False,
 | 
			
		||||
                 enable_scale_search=False):
 | 
			
		||||
        super().__init__(input_features, output_features, bias)
 | 
			
		||||
        self.weight = FP4Params(self.weight.data,
 | 
			
		||||
                                requires_grad=False,
 | 
			
		||||
                                quantized=False, _shape=None, qtype=qtype,
 | 
			
		||||
                                enable_xetla=enable_xetla,
 | 
			
		||||
                                enable_scale_search=enable_scale_search)
 | 
			
		||||
        self.in_len = input_features
 | 
			
		||||
        self.out_len = output_features
 | 
			
		||||
| 
						 | 
				
			
			@ -708,7 +584,6 @@ class LowBitLinear(nn.Linear):
 | 
			
		|||
        self.conver_to_half = conver_to_half
 | 
			
		||||
        self.mp_group = mp_group
 | 
			
		||||
        self.compute_dtype = None  # only for training
 | 
			
		||||
        self.enable_xetla = enable_xetla
 | 
			
		||||
        self.optimize_lm_head = optimize_lm_head
 | 
			
		||||
        self.device = None  # detected only once in the first forward
 | 
			
		||||
        # empty cache before and after lm_head at first token (by default on arc)
 | 
			
		||||
| 
						 | 
				
			
			@ -799,9 +674,6 @@ class LowBitLinear(nn.Linear):
 | 
			
		|||
                                                       self.weight.data,
 | 
			
		||||
                                                       self.weight.qtype,
 | 
			
		||||
                                                       input_seq_size)
 | 
			
		||||
            elif self.enable_xetla:
 | 
			
		||||
                x_2d = x_2d.half()
 | 
			
		||||
                result = xe_linear.mm_xetla(x_2d, self.weight.data, self.qtype)
 | 
			
		||||
            else:
 | 
			
		||||
                # inference path
 | 
			
		||||
                # current workaround to reduce first token latency of fp32 input
 | 
			
		||||
| 
						 | 
				
			
			@ -880,8 +752,7 @@ class LowBitLinear(nn.Linear):
 | 
			
		|||
 | 
			
		||||
class FP16Linear(nn.Linear):
 | 
			
		||||
    def __init__(self, input_features, output_features, bias=True,
 | 
			
		||||
                 mp_group=None, weight_type=1, enable_xetla=False,
 | 
			
		||||
                 optimize_lm_head=False):
 | 
			
		||||
                 mp_group=None, weight_type=1, optimize_lm_head=False):
 | 
			
		||||
        super().__init__(input_features, output_features, bias)
 | 
			
		||||
        self.in_len = input_features
 | 
			
		||||
        self.out_len = output_features
 | 
			
		||||
| 
						 | 
				
			
			@ -894,7 +765,6 @@ class FP16Linear(nn.Linear):
 | 
			
		|||
        # weigh_type = 3 means weight has been transposed by esimd method
 | 
			
		||||
        self.weight_type = 1
 | 
			
		||||
        self.optimize_lm_head = optimize_lm_head
 | 
			
		||||
        self.enable_xetla = enable_xetla
 | 
			
		||||
 | 
			
		||||
    def forward(self, x: torch.Tensor):
 | 
			
		||||
        # only work for GPU
 | 
			
		||||
| 
						 | 
				
			
			@ -1010,8 +880,7 @@ class FP16Linear(nn.Linear):
 | 
			
		|||
 | 
			
		||||
class BF16Linear(nn.Linear):
 | 
			
		||||
    def __init__(self, input_features, output_features, bias=True,
 | 
			
		||||
                 mp_group=None, compute_dtype=None, enable_xetla=False,
 | 
			
		||||
                 optimize_lm_head=False):
 | 
			
		||||
                 mp_group=None, compute_dtype=None, optimize_lm_head=False):
 | 
			
		||||
        super().__init__(input_features, output_features, bias)
 | 
			
		||||
        self.in_len = input_features
 | 
			
		||||
        self.out_len = output_features
 | 
			
		||||
| 
						 | 
				
			
			@ -1021,7 +890,6 @@ class BF16Linear(nn.Linear):
 | 
			
		|||
        self.mp_group = mp_group
 | 
			
		||||
        self.compute_dtype = compute_dtype
 | 
			
		||||
        self.optimize_lm_head = optimize_lm_head
 | 
			
		||||
        self.enable_xetla = enable_xetla
 | 
			
		||||
 | 
			
		||||
    def forward(self, x: torch.Tensor):
 | 
			
		||||
        if self.optimize_lm_head:
 | 
			
		||||
| 
						 | 
				
			
			@ -1050,11 +918,11 @@ class BF16Linear(nn.Linear):
 | 
			
		|||
 | 
			
		||||
class vLLMLowBitLinear(LowBitLinear):
 | 
			
		||||
    def __init__(self, input_features, output_features, qtype, bias=True,
 | 
			
		||||
                 conver_to_half=True, mp_group=None, enable_xetla=False,
 | 
			
		||||
                 conver_to_half=True, mp_group=None,
 | 
			
		||||
                 optimize_lm_head=False, act_order=False,
 | 
			
		||||
                 enable_scale_search=False):
 | 
			
		||||
        super().__init__(input_features, output_features, qtype, bias, conver_to_half, mp_group,
 | 
			
		||||
                         enable_xetla, optimize_lm_head, act_order, enable_scale_search)
 | 
			
		||||
                         optimize_lm_head, act_order, enable_scale_search)
 | 
			
		||||
 | 
			
		||||
    def forward(self, x: torch.Tensor):
 | 
			
		||||
        result = super().forward(x)
 | 
			
		||||
| 
						 | 
				
			
			@ -1063,9 +931,9 @@ class vLLMLowBitLinear(LowBitLinear):
 | 
			
		|||
 | 
			
		||||
class vLLMFP16Linear(FP16Linear):
 | 
			
		||||
    def __init__(self, input_features, output_features, bias=True, mp_group=None, weight_type=1,
 | 
			
		||||
                 enable_xetla=False, optimize_lm_head=False):
 | 
			
		||||
                 optimize_lm_head=False):
 | 
			
		||||
        super().__init__(input_features, output_features, bias, mp_group, weight_type,
 | 
			
		||||
                         enable_xetla, optimize_lm_head)
 | 
			
		||||
                         optimize_lm_head)
 | 
			
		||||
 | 
			
		||||
    def forward(self, x: torch.Tensor):
 | 
			
		||||
        result = super().forward(x)
 | 
			
		||||
| 
						 | 
				
			
			@ -1074,9 +942,9 @@ class vLLMFP16Linear(FP16Linear):
 | 
			
		|||
 | 
			
		||||
class vLLMBF16Linear(BF16Linear):
 | 
			
		||||
    def __init__(self, input_features, output_features, bias=True, mp_group=None,
 | 
			
		||||
                 compute_dtype=None, enable_xetla=False, optimize_lm_head=False):
 | 
			
		||||
                 compute_dtype=None, optimize_lm_head=False):
 | 
			
		||||
        super().__init__(input_features, output_features, bias, mp_group, compute_dtype,
 | 
			
		||||
                         enable_xetla, optimize_lm_head)
 | 
			
		||||
                         optimize_lm_head)
 | 
			
		||||
 | 
			
		||||
    def forward(self, x: torch.Tensor):
 | 
			
		||||
        result = super().forward(x)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -448,7 +448,6 @@ class _BaseAutoModelClass:
 | 
			
		|||
        mixed_precision = kwargs.pop("mixed_precision", False)
 | 
			
		||||
        if embedding_qtype is not None:
 | 
			
		||||
            embedding_qtype = ggml_tensor_qtype[embedding_qtype]
 | 
			
		||||
        enable_xetla = kwargs.pop("enable_xetla", False)
 | 
			
		||||
        _args = copy.deepcopy(args)
 | 
			
		||||
        _kwargs = copy.deepcopy(kwargs)
 | 
			
		||||
        awq_config = None
 | 
			
		||||
| 
						 | 
				
			
			@ -518,7 +517,6 @@ class _BaseAutoModelClass:
 | 
			
		|||
                                     torch_dtype=kwargs.get("torch_dtype", 'auto'),
 | 
			
		||||
                                     imatrix_data=imatrix_data,
 | 
			
		||||
                                     embedding_qtype=embedding_qtype,
 | 
			
		||||
                                     enable_xetla=enable_xetla,
 | 
			
		||||
                                     mixed_precision=mixed_precision)
 | 
			
		||||
 | 
			
		||||
        if disk_embedding:
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -67,7 +67,7 @@ def baichuan_mlp_forward(
 | 
			
		|||
) -> torch.Tensor:
 | 
			
		||||
    x_2d = x.view(-1, x.shape[-1])
 | 
			
		||||
    qtype = getattr(self.gate_proj, "qtype", None)
 | 
			
		||||
    if mlp_fusion_check(x_2d, qtype, self.training) and not self.down_proj.enable_xetla:
 | 
			
		||||
    if mlp_fusion_check(x_2d, qtype, self.training):
 | 
			
		||||
        import xe_linear
 | 
			
		||||
        if not x_2d.is_contiguous():
 | 
			
		||||
            x_2d = x_2d.contiguous()
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -380,7 +380,7 @@ def mixtral_mlp_forward(
 | 
			
		|||
    routing_weights
 | 
			
		||||
) -> torch.Tensor:
 | 
			
		||||
    qtype = getattr(self.w1, "qtype", None)
 | 
			
		||||
    if mlp_fusion_check(x, qtype, self.training) and not self.w1.enable_xetla:
 | 
			
		||||
    if mlp_fusion_check(x, qtype, self.training):
 | 
			
		||||
        import xe_linear
 | 
			
		||||
        return self.w2(xe_linear.mlp_forward_xpu(
 | 
			
		||||
            x, self.w1.weight.data, self.w3.weight.data,
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -259,7 +259,7 @@ def qwen_attention_forward_registered(
 | 
			
		|||
def qwen_mlp_forward(self, x: torch.Tensor) -> torch.Tensor:
 | 
			
		||||
    x_2d = x.view(-1, x.shape[-1])
 | 
			
		||||
    qtype = getattr(self.w1, "qtype", None)
 | 
			
		||||
    if mlp_fusion_check(x_2d, qtype, self.training) and not self.w1.enable_xetla:
 | 
			
		||||
    if mlp_fusion_check(x_2d, qtype, self.training):
 | 
			
		||||
        import xe_linear
 | 
			
		||||
        if not x_2d.is_contiguous():
 | 
			
		||||
            x_2d = x_2d.contiguous()
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -612,7 +612,7 @@ def qwen2_mlp_forward(
 | 
			
		|||
) -> torch.Tensor:
 | 
			
		||||
    x_2d = x.view(-1, x.shape[-1])
 | 
			
		||||
    qtype = getattr(self.gate_proj, "qtype", None)
 | 
			
		||||
    if mlp_fusion_check(x_2d, qtype, self.training) and not self.down_proj.enable_xetla:
 | 
			
		||||
    if mlp_fusion_check(x_2d, qtype, self.training):
 | 
			
		||||
        import xe_linear
 | 
			
		||||
        return self.down_proj(xe_linear.mlp_forward_xpu(
 | 
			
		||||
            x_2d, self.gate_proj.weight.data, self.up_proj.weight.data,
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -337,8 +337,7 @@ def use_decoding_fast_path(proj,
 | 
			
		|||
        return False
 | 
			
		||||
    if bs != 1:
 | 
			
		||||
        return False
 | 
			
		||||
    if proj.enable_xetla:
 | 
			
		||||
        return False
 | 
			
		||||
 | 
			
		||||
    if device in ["uhd"]:
 | 
			
		||||
        return False
 | 
			
		||||
    return True
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in a new issue