Use new layout for xpu qlinear (#8896)

* use new layout for xpu qlinear

* fix style
Yang Wang 2023-09-07 12:55:33 +08:00 committed by GitHub
parent 8bc1d8a17c
commit c34400e6b0
2 changed files with 87 additions and 4 deletions


@@ -999,6 +999,42 @@ _lib.ggml_dequantize_q4_0.argtypes = [
 _lib.ggml_quantize_q4_0.restype = None
+def ggml_q_format_convet_cpu2xpu(
+    src: ctypes.c_void_p,
+    dst: ctypes.c_void_p,
+    n: ctypes.c_int,
+    qtype: ctypes.c_int
+):
+    _lib.ggml_q_format_convet_cpu2xpu(src, dst, n, qtype)
+
+
+_lib.ggml_q_format_convet_cpu2xpu.argtypes = [
+    ctypes.c_void_p,
+    ctypes.c_void_p,
+    ctypes.c_int,
+    ctypes.c_int,
+]
+_lib.ggml_q_format_convet_cpu2xpu.restype = None
+
+
+def ggml_q_format_convet_xpu2cpu(
+    src: ctypes.c_void_p,
+    dst: ctypes.c_void_p,
+    n: ctypes.c_int,
+    qtype: ctypes.c_int
+):
+    _lib.ggml_q_format_convet_xpu2cpu(src, dst, n, qtype)
+
+
+_lib.ggml_q_format_convet_xpu2cpu.argtypes = [
+    ctypes.c_void_p,
+    ctypes.c_void_p,
+    ctypes.c_int,
+    ctypes.c_int
+]
+_lib.ggml_q_format_convet_xpu2cpu.restype = None
 # def ggml_dequantize_nf4(
 #     src: ctypes.c_void_p,
 #     dst: ctypes.c_void_p,
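The two new entry points follow the same ctypes convention as the existing quantize/dequantize bindings: the caller allocates both buffers and passes raw data pointers plus the element count and quantization type. A minimal sketch of how a caller might drive the cpu2xpu direction (the helper name `to_xpu_layout` and the surrounding values are illustrative, not part of this change):

```python
import ctypes
import torch

# Illustrative sketch only: driving the new binding with caller-owned buffers.
# `ggml_q_format_convet_cpu2xpu` is the wrapper added above; the rest is made up.
def to_xpu_layout(quantized_bytes: torch.Tensor, num_elem: int, qtype: int) -> torch.Tensor:
    # quantized weights travel as a flat uint8 byte buffer on the CPU
    dst = torch.empty_like(quantized_bytes)            # destination buffer, same byte size
    ggml_q_format_convet_cpu2xpu(
        ctypes.c_void_p(quantized_bytes.data_ptr()),   # src: ggml CPU block layout
        ctypes.c_void_p(dst.data_ptr()),               # dst: rewritten into the XPU layout
        num_elem,                                      # n: number of original weight elements
        qtype,                                         # qtype: quantization format id
    )
    return dst
```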


@@ -47,6 +47,8 @@ import os
 import torch
 import torch.nn.functional as F
 from torch import Tensor, device, dtype, nn
+from operator import mul
+from functools import reduce
 T = TypeVar("T", bound="torch.nn.Module")
@@ -86,6 +88,38 @@ def ggml_convert_qtype(tensor: torch.Tensor, qtype: int, device=None):
     return dst_tensor
+def ggml_q_format_convet_cpu2xpu(tensor: torch.Tensor, num_elem: int, qtype: int):
+    invalidInputError(tensor.dtype == torch.uint8,
+                      "Input tensor must be uint8")
+    invalidInputError(tensor.device == torch.device('cpu'),
+                      "Input tensor must be on the CPU device")
+    src = ctypes.c_void_p(tensor.data.data_ptr())
+    dst_tensor = torch.empty_like(tensor)
+    dst = ctypes.c_void_p(dst_tensor.data.data_ptr())
+    ggml.ggml_q_format_convet_cpu2xpu(src, dst, num_elem, qtype)
+    return dst_tensor
+
+
+def ggml_q_format_convet_xpu2cpu(tensor: torch.Tensor, num_elem: int, qtype: int):
+    invalidInputError(tensor.dtype == torch.uint8,
+                      "Input tensor must be uint8")
+    invalidInputError(tensor.device == torch.device('cpu'),
+                      "Input tensor must be on the CPU device")
+    src = ctypes.c_void_p(tensor.data.data_ptr())
+    dst_tensor = torch.empty_like(tensor)
+    dst = ctypes.c_void_p(dst_tensor.data.data_ptr())
+    ggml.ggml_q_format_convet_xpu2cpu(src, dst, num_elem, qtype)
+    return dst_tensor
 def ggml_int4_convert_fp32(tensor: torch.Tensor, weight_shape: tuple, k: int):
     invalidInputError(tensor.dtype == torch.uint8,
                       "Input tensor must be uint8")
@@ -106,7 +140,6 @@ class FP4Params(torch.nn.Parameter):
     def __new__(cls,
                 data=None,
                 requires_grad=False,
-                old_data=None,
                 quantized=False,
                 _shape=None,
                 qtype=None):
@@ -154,7 +187,10 @@ class FP4Params(torch.nn.Parameter):
             return self.quantize(device.type)
         elif (device is not None and device.type == "xpu" and self.data.device.type == "cpu"):
             # enter xpu logic, compile linear_int4 extension at first time
-            q_tensor = self.quantize(device)  # tensor is cpu now
+            self.quantize(device)  # tensor is cpu now
+            self.data = ggml_q_format_convet_cpu2xpu(self.data,
+                                                     reduce(mul, self._shape, 1),
+                                                     self.qtype)
             new_param = FP4Params(super().to(device=device,
                                              dtype=dtype,
                                              non_blocking=non_blocking),
@@ -163,6 +199,18 @@ class FP4Params(torch.nn.Parameter):
                                   _shape=self._shape,
                                   qtype=self.qtype)
             return new_param
+        elif (device is not None and device.type == "cpu" and self.data.device.type == "xpu"):
+            new_param = FP4Params(super().to(device=device,
+                                             dtype=dtype,
+                                             non_blocking=non_blocking),
+                                  requires_grad=self.requires_grad,
+                                  quantized=self.quantized,
+                                  _shape=self._shape,
+                                  qtype=self.qtype)
+            new_param.data = ggml_q_format_convet_xpu2cpu(new_param.data,
+                                                          reduce(mul, new_param._shape, 1),
+                                                          new_param.qtype)
+            return new_param
         else:
             new_param = FP4Params(super().to(device=device,
                                              dtype=dtype,
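Taken together, the two branches above make `.to()` round-trippable for quantized parameters: moving to XPU rewrites the quantized bytes into the new layout, and moving back to CPU restores the original ggml layout. A rough usage sketch (the weight shape and `qtype` value are placeholders, and an available XPU device is assumed):

```python
import torch

# Illustrative only; in practice FP4Params is constructed by LowBitLinear below.
weight = torch.randn(4096, 4096)
param = FP4Params(weight, requires_grad=False,
                  quantized=False, _shape=None, qtype=2)   # qtype id is a placeholder

param = param.to(torch.device("xpu"))   # quantizes on CPU, then converts bytes to the XPU layout
param = param.to(torch.device("cpu"))   # copies bytes back and restores the CPU ggml layout
```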
@@ -217,7 +265,6 @@ class LowBitLinear(nn.Linear):
         super().__init__(input_features, output_features, bias)
         self.weight = FP4Params(self.weight.data,
                                 requires_grad=False,
-                                old_data=self.weight.data,
                                 quantized=False, _shape=None, qtype=qtype)
         self.in_len = input_features
         self.out_len = output_features
@@ -249,7 +296,7 @@ class LowBitLinear(nn.Linear):
             if x_2d.shape[0] > 1 and x_2d.dtype == torch.float32:
                 x_2d = x_2d.half()
             # input format of linear_q4.forward is 1: input, 2: weight
-            result = linear_q4_0.forward(x_2d, x0, self.qtype)
+            result = linear_q4_0.forward_new(x_2d, x0, self.qtype)
             new_shape = x_shape[:-1] + (self.out_len,)
             result = result.view(new_shape)
             if self.bias is not None:
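For context, a condensed sketch of the XPU forward path around the changed call (illustrative, not the full method): `x0` is the quantized uint8 weight already rewritten into the new layout by the `.to()` logic above, `linear_q4_0` is the compiled extension loaded elsewhere in this file, and `forward_new` is the kernel entry point that understands the new layout.

```python
import torch

# Illustrative reconstruction of the surrounding forward logic; names follow the diff.
def xpu_forward(self, x: torch.Tensor) -> torch.Tensor:
    x_shape = x.shape
    x_2d = x.view(-1, x_shape[-1])                 # collapse leading dims into one batch dim
    x0 = self.weight.data                          # quantized bytes, XPU layout after .to("xpu")
    if x_2d.shape[0] > 1 and x_2d.dtype == torch.float32:
        x_2d = x_2d.half()                         # batched fp32 inputs run through the fp16 kernel
    result = linear_q4_0.forward_new(x_2d, x0, self.qtype)   # matmul against the new weight layout
    new_shape = x_shape[:-1] + (self.out_len,)
    result = result.view(new_shape)                # restore the original leading dims
    if self.bias is not None:
        result = result + self.bias                # assumed bias handling, not shown in the hunk
    return result
```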