Use new layout for xpu qlinear (#8896)
* use new layout for xpu qlinear
* fix style
This commit is contained in:
parent
8bc1d8a17c
commit
c34400e6b0
2 changed files with 87 additions and 4 deletions
@@ -999,6 +999,42 @@ _lib.ggml_dequantize_q4_0.argtypes = [
 _lib.ggml_quantize_q4_0.restype = None
 
 
+def ggml_q_format_convet_cpu2xpu(
+    src: ctypes.c_void_p,
+    dst: ctypes.c_void_p,
+    n: ctypes.c_int,
+    qtype: ctypes.c_int
+):
+    _lib.ggml_q_format_convet_cpu2xpu(src, dst, n, qtype)
+
+
+_lib.ggml_q_format_convet_cpu2xpu.argtypes = [
+    ctypes.c_void_p,
+    ctypes.c_void_p,
+    ctypes.c_int,
+    ctypes.c_int,
+]
+_lib.ggml_q_format_convet_cpu2xpu.restype = None
+
+
+def ggml_q_format_convet_xpu2cpu(
+    src: ctypes.c_void_p,
+    dst: ctypes.c_void_p,
+    n: ctypes.c_int,
+    qtype: ctypes.c_int
+):
+    _lib.ggml_q_format_convet_xpu2cpu(src, dst, n, qtype)
+
+
+_lib.ggml_q_format_convet_xpu2cpu.argtypes = [
+    ctypes.c_void_p,
+    ctypes.c_void_p,
+    ctypes.c_int,
+    ctypes.c_int
+]
+_lib.ggml_q_format_convet_xpu2cpu.restype = None
+
+
 # def ggml_dequantize_nf4(
 #     src: ctypes.c_void_p,
 #     dst: ctypes.c_void_p,
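For orientation, here is a minimal sketch of how these raw bindings are driven from Python; the element count, qtype id, and blob size below are illustrative placeholders (the real byte size of a packed blob depends on the quantization block format):

import ctypes
import torch

# Hypothetical values: a real caller gets the element count from the weight
# shape and the qtype id from the model's quantization config.
n, qtype = 16384, 2

# Quantized weights live in opaque uint8 blobs; this size assumes ggml's
# 18-bytes-per-32-element q4_0 packing.
packed = torch.zeros(n // 32 * 18, dtype=torch.uint8)
converted = torch.empty_like(packed)

# The bindings take raw pointers, which torch exposes via data_ptr().
src = ctypes.c_void_p(packed.data_ptr())
dst = ctypes.c_void_p(converted.data_ptr())
ggml_q_format_convet_cpu2xpu(src, dst, n, qtype)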
@@ -47,6 +47,8 @@ import os
 import torch
 import torch.nn.functional as F
 from torch import Tensor, device, dtype, nn
+from operator import mul
+from functools import reduce
 
 T = TypeVar("T", bound="torch.nn.Module")
 
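The two new imports exist only to turn a shape tuple into a flat element count later in the file; a one-line illustration of the idiom:

from functools import reduce
from operator import mul

# reduce(mul, shape, 1) multiplies the dimensions together; the initial 1
# makes an empty shape (a scalar) come out as 1 instead of raising.
assert reduce(mul, (4096, 4096), 1) == 16777216
assert reduce(mul, (), 1) == 1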
@@ -86,6 +88,38 @@ def ggml_convert_qtype(tensor: torch.Tensor, qtype: int, device=None):
     return dst_tensor
 
 
+def ggml_q_format_convet_cpu2xpu(tensor: torch.Tensor, num_elem: int, qtype: int):
+
+    invalidInputError(tensor.dtype == torch.uint8,
+                      "Input tensor must be uint8")
+
+    invalidInputError(tensor.device == torch.device('cpu'),
+                      "Input tensor must be on the CPU device")
+
+    src = ctypes.c_void_p(tensor.data.data_ptr())
+
+    dst_tensor = torch.empty_like(tensor)
+    dst = ctypes.c_void_p(dst_tensor.data.data_ptr())
+    ggml.ggml_q_format_convet_cpu2xpu(src, dst, num_elem, qtype)
+    return dst_tensor
+
+
+def ggml_q_format_convet_xpu2cpu(tensor: torch.Tensor, num_elem: int, qtype: int):
+
+    invalidInputError(tensor.dtype == torch.uint8,
+                      "Input tensor must be uint8")
+
+    invalidInputError(tensor.device == torch.device('cpu'),
+                      "Input tensor must be on the CPU device")
+
+    src = ctypes.c_void_p(tensor.data.data_ptr())
+
+    dst_tensor = torch.empty_like(tensor)
+    dst = ctypes.c_void_p(dst_tensor.data.data_ptr())
+    ggml.ggml_q_format_convet_xpu2cpu(src, dst, num_elem, qtype)
+    return dst_tensor
+
+
 def ggml_int4_convert_fp32(tensor: torch.Tensor, weight_shape: tuple, k: int):
     invalidInputError(tensor.dtype == torch.uint8,
                       "Input tensor must be uint8")
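A hedged usage sketch of the two wrappers above; the blob size again assumes ggml's 18-bytes-per-32-element q4_0 packing, and the qtype id is a placeholder:

import torch

num_elem = 4096 * 4096       # weight elements, e.g. a 4096x4096 linear layer
qtype = 2                    # placeholder qtype id
cpu_blob = torch.zeros(num_elem // 32 * 18, dtype=torch.uint8)  # assumed q4_0 size

# Both calls operate on CPU memory; only the byte layout changes, so the
# round trip is expected to reproduce the original blob.
xpu_layout = ggml_q_format_convet_cpu2xpu(cpu_blob, num_elem, qtype)
round_trip = ggml_q_format_convet_xpu2cpu(xpu_layout, num_elem, qtype)
assert torch.equal(round_trip, cpu_blob)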
@@ -106,7 +140,6 @@ class FP4Params(torch.nn.Parameter):
     def __new__(cls,
                 data=None,
                 requires_grad=False,
-                old_data=None,
                 quantized=False,
                 _shape=None,
                 qtype=None):
@@ -154,7 +187,10 @@ class FP4Params(torch.nn.Parameter):
             return self.quantize(device.type)
         elif (device is not None and device.type == "xpu" and self.data.device.type == "cpu"):
             # enter xpu logic, compile linear_int4 extension at first time
-            q_tensor = self.quantize(device)  # tensor is cpu now
+            self.quantize(device)  # tensor is cpu now
+            self.data = ggml_q_format_convet_cpu2xpu(self.data,
+                                                     reduce(mul, self._shape, 1),
+                                                     self.qtype)
             new_param = FP4Params(super().to(device=device,
                                              dtype=dtype,
                                              non_blocking=non_blocking),
@@ -163,6 +199,18 @@ class FP4Params(torch.nn.Parameter):
                                   _shape=self._shape,
                                   qtype=self.qtype)
             return new_param
+        elif (device is not None and device.type == "cpu" and self.data.device.type == "xpu"):
+            new_param = FP4Params(super().to(device=device,
+                                             dtype=dtype,
+                                             non_blocking=non_blocking),
+                                  requires_grad=self.requires_grad,
+                                  quantized=self.quantized,
+                                  _shape=self._shape,
+                                  qtype=self.qtype)
+            new_param.data = ggml_q_format_convet_xpu2cpu(new_param.data,
+                                                          reduce(mul, new_param._shape, 1),
+                                                          new_param.qtype)
+            return new_param
         else:
             new_param = FP4Params(super().to(device=device,
                                              dtype=dtype,
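Taken together, the two branches make FP4Params.to() the single conversion point between the two layouts; a rough sketch of the intended flow (the qtype id and constructor argument order are assumptions, and an XPU build of the library is required):

import torch

linear = LowBitLinear(4096, 4096, qtype=2, bias=False)  # hypothetical construction

# cpu -> xpu: quantize on CPU first, then ggml_q_format_convet_cpu2xpu rewrites
# the blob into the new XPU layout before the device transfer.
linear = linear.to(torch.device('xpu'))

# xpu -> cpu: transfer first, then ggml_q_format_convet_xpu2cpu restores the
# ggml CPU layout so the same weights keep working with the CPU kernels.
linear = linear.to(torch.device('cpu'))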
@@ -217,7 +265,6 @@ class LowBitLinear(nn.Linear):
         super().__init__(input_features, output_features, bias)
         self.weight = FP4Params(self.weight.data,
                                 requires_grad=False,
-                                old_data=self.weight.data,
                                 quantized=False, _shape=None, qtype=qtype)
         self.in_len = input_features
         self.out_len = output_features
@@ -249,7 +296,7 @@ class LowBitLinear(nn.Linear):
             if x_2d.shape[0] > 1 and x_2d.dtype == torch.float32:
                 x_2d = x_2d.half()
             # input format of linear_q4.forward is 1: input, 2: weight
-            result = linear_q4_0.forward(x_2d, x0, self.qtype)
+            result = linear_q4_0.forward_new(x_2d, x0, self.qtype)
             new_shape = x_shape[:-1] + (self.out_len,)
             result = result.view(new_shape)
             if self.bias is not None:
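On the forward path the only functional change is swapping linear_q4_0.forward for linear_q4_0.forward_new, which consumes the new XPU weight layout; continuing the sketch above under the same assumptions:

import torch

x = torch.randn(1, 128, 4096, device='xpu', dtype=torch.float16)
y = linear(x)   # internally flattens to 2D, calls linear_q4_0.forward_new,
                # then reshapes back to x.shape[:-1] + (out_len,)
print(y.shape)  # torch.Size([1, 128, 4096])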