From c34400e6b0231b563e36aedd50b9149f4d4ea0f0 Mon Sep 17 00:00:00 2001
From: Yang Wang
Date: Thu, 7 Sep 2023 12:55:33 +0800
Subject: [PATCH] Use new layout for xpu qlinear (#8896)

* use new layout for xpu qlinear

* fix style
---
 .../bigdl/llm/ggml/model/llama/llama_cpp.py   | 36 ++++++++++++
 .../bigdl/llm/transformers/low_bit_linear.py  | 55 +++++++++++++++++--
 2 files changed, 87 insertions(+), 4 deletions(-)

diff --git a/python/llm/src/bigdl/llm/ggml/model/llama/llama_cpp.py b/python/llm/src/bigdl/llm/ggml/model/llama/llama_cpp.py
index ed62e5f7..14f3bf8a 100644
--- a/python/llm/src/bigdl/llm/ggml/model/llama/llama_cpp.py
+++ b/python/llm/src/bigdl/llm/ggml/model/llama/llama_cpp.py
@@ -999,6 +999,42 @@ _lib.ggml_dequantize_q4_0.argtypes = [
 _lib.ggml_quantize_q4_0.restype = None
 
 
+def ggml_q_format_convet_cpu2xpu(
+    src: ctypes.c_void_p,
+    dst: ctypes.c_void_p,
+    n: ctypes.c_int,
+    qtype: ctypes.c_int
+):
+    _lib.ggml_q_format_convet_cpu2xpu(src, dst, n, qtype)
+
+
+_lib.ggml_q_format_convet_cpu2xpu.argtypes = [
+    ctypes.c_void_p,
+    ctypes.c_void_p,
+    ctypes.c_int,
+    ctypes.c_int,
+]
+_lib.ggml_q_format_convet_cpu2xpu.restype = None
+
+
+def ggml_q_format_convet_xpu2cpu(
+    src: ctypes.c_void_p,
+    dst: ctypes.c_void_p,
+    n: ctypes.c_int,
+    qtype: ctypes.c_int
+):
+    _lib.ggml_q_format_convet_xpu2cpu(src, dst, n, qtype)
+
+
+_lib.ggml_q_format_convet_xpu2cpu.argtypes = [
+    ctypes.c_void_p,
+    ctypes.c_void_p,
+    ctypes.c_int,
+    ctypes.c_int
+]
+_lib.ggml_q_format_convet_xpu2cpu.restype = None
+
+
 # def ggml_dequantize_nf4(
 #     src: ctypes.c_void_p,
 #     dst: ctypes.c_void_p,
diff --git a/python/llm/src/bigdl/llm/transformers/low_bit_linear.py b/python/llm/src/bigdl/llm/transformers/low_bit_linear.py
index 0668c7c9..241f6109 100644
--- a/python/llm/src/bigdl/llm/transformers/low_bit_linear.py
+++ b/python/llm/src/bigdl/llm/transformers/low_bit_linear.py
@@ -47,6 +47,8 @@ import os
 import torch
 import torch.nn.functional as F
 from torch import Tensor, device, dtype, nn
+from operator import mul
+from functools import reduce
 
 T = TypeVar("T", bound="torch.nn.Module")
 
@@ -86,6 +88,38 @@ def ggml_convert_qtype(tensor: torch.Tensor, qtype: int, device=None):
     return dst_tensor
 
 
+def ggml_q_format_convet_cpu2xpu(tensor: torch.Tensor, num_elem: int, qtype: int):
+
+    invalidInputError(tensor.dtype == torch.uint8,
+                      "Input tensor must be uint8")
+
+    invalidInputError(tensor.device == torch.device('cpu'),
+                      "Input tensor must be uint8")
+
+    src = ctypes.c_void_p(tensor.data.data_ptr())
+
+    dst_tensor = torch.empty_like(tensor)
+    dst = ctypes.c_void_p(dst_tensor.data.data_ptr())
+    ggml.ggml_q_format_convet_cpu2xpu(src, dst, num_elem, qtype)
+    return dst_tensor
+
+
+def ggml_q_format_convet_xpu2cpu(tensor: torch.Tensor, num_elem: int, qtype: int):
+
+    invalidInputError(tensor.dtype == torch.uint8,
+                      "Input tensor must be uint8")
+
+    invalidInputError(tensor.device == torch.device('cpu'),
+                      "Input tensor must be uint8")
+
+    src = ctypes.c_void_p(tensor.data.data_ptr())
+
+    dst_tensor = torch.empty_like(tensor)
+    dst = ctypes.c_void_p(dst_tensor.data.data_ptr())
+    ggml.ggml_q_format_convet_xpu2cpu(src, dst, num_elem, qtype)
+    return dst_tensor
+
+
 def ggml_int4_convert_fp32(tensor: torch.Tensor, weight_shape: tuple, k: int):
     invalidInputError(tensor.dtype == torch.uint8,
                       "Input tensor must be uint8")
@@ -106,7 +140,6 @@ class FP4Params(torch.nn.Parameter):
     def __new__(cls,
                 data=None,
                 requires_grad=False,
-                old_data=None,
                 quantized=False,
                 _shape=None,
                 qtype=None):
@@ -154,7 +187,10 @@ class FP4Params(torch.nn.Parameter):
             return self.quantize(device.type)
         elif (device is not None and device.type == "xpu" and self.data.device.type == "cpu"):
             # enter xpu logic, compile linear_int4 extension at first time
-            q_tensor = self.quantize(device)  # tensor is cpu now
+            self.quantize(device)  # tensor is cpu now
+            self.data = ggml_q_format_convet_cpu2xpu(self.data,
+                                                     reduce(mul, self._shape, 1),
+                                                     self.qtype)
             new_param = FP4Params(super().to(device=device,
                                              dtype=dtype,
                                              non_blocking=non_blocking),
@@ -163,6 +199,18 @@ class FP4Params(torch.nn.Parameter):
                                   _shape=self._shape,
                                   qtype=self.qtype)
             return new_param
+        elif (device is not None and device.type == "cpu" and self.data.device.type == "xpu"):
+            new_param = FP4Params(super().to(device=device,
+                                             dtype=dtype,
+                                             non_blocking=non_blocking),
+                                  requires_grad=self.requires_grad,
+                                  quantized=self.quantized,
+                                  _shape=self._shape,
+                                  qtype=self.qtype)
+            new_param.data = ggml_q_format_convet_xpu2cpu(new_param.data,
+                                                          reduce(mul, new_param._shape, 1),
+                                                          new_param.qtype)
+            return new_param
         else:
             new_param = FP4Params(super().to(device=device,
                                              dtype=dtype,
@@ -217,7 +265,6 @@ class LowBitLinear(nn.Linear):
         super().__init__(input_features, output_features, bias)
         self.weight = FP4Params(self.weight.data,
                                 requires_grad=False,
-                                old_data=self.weight.data,
                                 quantized=False, _shape=None, qtype=qtype)
         self.in_len = input_features
         self.out_len = output_features
@@ -249,7 +296,7 @@ class LowBitLinear(nn.Linear):
             if x_2d.shape[0] > 1 and x_2d.dtype == torch.float32:
                 x_2d = x_2d.half()
             # input format of linear_q4.forward is 1: input, 2: weight
-            result = linear_q4_0.forward(x_2d, x0, self.qtype)
+            result = linear_q4_0.forward_new(x_2d, x0, self.qtype)
             new_shape = x_shape[:-1] + (self.out_len,)
             result = result.view(new_shape)
             if self.bias is not None:
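
Usage sketch (illustrative only, not part of the diff above): the two new helpers re-pack
quantized weight bytes between the CPU ggml block layout and the layout consumed by the XPU
qlinear kernel, and FP4Params.to() now invokes them automatically when a quantized parameter
moves between devices. The snippet below exercises the round trip directly on CPU tensors.
Only ggml_convert_qtype, ggml_q_format_convet_cpu2xpu and ggml_q_format_convet_xpu2cpu are
taken from this patch; the import path of ggml_tensor_qtype, the "sym_int4" key and the toy
weight shape are assumptions, and a bigdl-llm build whose native library exports the
ggml_q_format_convet_* symbols is required.

    import torch
    from operator import mul
    from functools import reduce

    # Helpers defined in python/llm/src/bigdl/llm/transformers/low_bit_linear.py (this patch).
    from bigdl.llm.transformers.low_bit_linear import (
        ggml_convert_qtype,
        ggml_q_format_convet_cpu2xpu,
        ggml_q_format_convet_xpu2cpu,
    )
    # Assumed location of the qtype lookup table; "sym_int4" maps to the ggml q4_0 format.
    from bigdl.llm.ggml.quantize import ggml_tensor_qtype

    qtype = ggml_tensor_qtype["sym_int4"]
    weight = torch.randn(64, 64, dtype=torch.float32)   # toy fp32 weight on CPU

    # Quantize on CPU into ggml's uint8 block storage (what FP4Params.quantize produces).
    q_cpu = ggml_convert_qtype(weight, qtype, device='cpu')

    # Re-pack the same bytes into the layout expected by linear_q4_0.forward_new on XPU.
    num_elem = reduce(mul, weight.shape, 1)              # element count of the fp32 weight
    q_xpu_layout = ggml_q_format_convet_cpu2xpu(q_cpu, num_elem, qtype)

    # Converting back is expected to recover the original CPU layout, since the
    # conversion only reorders bytes within the quantized blocks.
    q_back = ggml_q_format_convet_xpu2cpu(q_xpu_layout, num_elem, qtype)
    print(torch.equal(q_cpu, q_back))

Doing the re-packing once at parameter-transfer time means the XPU kernel sees its preferred
layout on every forward call without per-call conversion, and moving the parameter back to
CPU restores the original ggml format.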