Use new layout for xpu qlinear (#8896)

* use new layout for xpu qlinear
* fix style

parent 8bc1d8a17c
commit c34400e6b0

2 changed files with 87 additions and 4 deletions
@@ -999,6 +999,42 @@ _lib.ggml_dequantize_q4_0.argtypes = [
 _lib.ggml_quantize_q4_0.restype = None
 
 
+def ggml_q_format_convet_cpu2xpu(
+    src: ctypes.c_void_p,
+    dst: ctypes.c_void_p,
+    n: ctypes.c_int,
+    qtype: ctypes.c_int
+):
+    _lib.ggml_q_format_convet_cpu2xpu(src, dst, n, qtype)
+
+
+_lib.ggml_q_format_convet_cpu2xpu.argtypes = [
+    ctypes.c_void_p,
+    ctypes.c_void_p,
+    ctypes.c_int,
+    ctypes.c_int,
+]
+_lib.ggml_q_format_convet_cpu2xpu.restype = None
+
+
+def ggml_q_format_convet_xpu2cpu(
+    src: ctypes.c_void_p,
+    dst: ctypes.c_void_p,
+    n: ctypes.c_int,
+    qtype: ctypes.c_int
+):
+    _lib.ggml_q_format_convet_xpu2cpu(src, dst, n, qtype)
+
+
+_lib.ggml_q_format_convet_xpu2cpu.argtypes = [
+    ctypes.c_void_p,
+    ctypes.c_void_p,
+    ctypes.c_int,
+    ctypes.c_int
+]
+_lib.ggml_q_format_convet_xpu2cpu.restype = None
+
+
 # def ggml_dequantize_nf4(
 #     src: ctypes.c_void_p,
 #     dst: ctypes.c_void_p,
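Note on the pattern: the new bindings follow the standard ctypes recipe, declare argtypes/restype on a shared-library symbol, then call it with raw void* buffers. A minimal self-contained sketch of that recipe, using libc's memcpy as a stand-in so it runs without the ggml shared library:

import ctypes
import ctypes.util

# Load a shared library and declare a symbol's signature, exactly as the
# ggml_q_format_convet_* bindings above do against _lib.
# (POSIX assumed; on other platforms the library lookup differs.)
libc = ctypes.CDLL(ctypes.util.find_library("c"))
libc.memcpy.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t]
libc.memcpy.restype = ctypes.c_void_p

src = ctypes.create_string_buffer(b"quantized block bytes")
dst = ctypes.create_string_buffer(len(src.raw))
# Call through the declared signature with void* buffers.
libc.memcpy(ctypes.cast(dst, ctypes.c_void_p),
            ctypes.cast(src, ctypes.c_void_p),
            len(src.raw))
assert dst.raw == src.raw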
@@ -47,6 +47,8 @@ import os
 import torch
 import torch.nn.functional as F
 from torch import Tensor, device, dtype, nn
+from operator import mul
+from functools import reduce
 
 T = TypeVar("T", bound="torch.nn.Module")
 
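The two new imports serve a single idiom used below: reduce(mul, shape, 1) computes the total element count of a weight shape, which is then passed to the converters as num_elem. For example:

from functools import reduce
from operator import mul

shape = (4096, 4096)                # hypothetical weight shape
num_elem = reduce(mul, shape, 1)    # 1 is the initializer, so () -> 1
assert num_elem == 4096 * 4096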
@@ -86,6 +88,38 @@ def ggml_convert_qtype(tensor: torch.Tensor, qtype: int, device=None):
     return dst_tensor
 
 
+def ggml_q_format_convet_cpu2xpu(tensor: torch.Tensor, num_elem: int, qtype: int):
+
+    invalidInputError(tensor.dtype == torch.uint8,
+                      "Input tensor must be uint8")
+
+    invalidInputError(tensor.device == torch.device('cpu'),
+                      "Input tensor must be on CPU")
+
+    src = ctypes.c_void_p(tensor.data.data_ptr())
+
+    dst_tensor = torch.empty_like(tensor)
+    dst = ctypes.c_void_p(dst_tensor.data.data_ptr())
+    ggml.ggml_q_format_convet_cpu2xpu(src, dst, num_elem, qtype)
+    return dst_tensor
+
+
+def ggml_q_format_convet_xpu2cpu(tensor: torch.Tensor, num_elem: int, qtype: int):
+
+    invalidInputError(tensor.dtype == torch.uint8,
+                      "Input tensor must be uint8")
+
+    invalidInputError(tensor.device == torch.device('cpu'),
+                      "Input tensor must be on CPU")
+
+    src = ctypes.c_void_p(tensor.data.data_ptr())
+
+    dst_tensor = torch.empty_like(tensor)
+    dst = ctypes.c_void_p(dst_tensor.data.data_ptr())
+    ggml.ggml_q_format_convet_xpu2cpu(src, dst, num_elem, qtype)
+    return dst_tensor
+
+
 def ggml_int4_convert_fp32(tensor: torch.Tensor, weight_shape: tuple, k: int):
     invalidInputError(tensor.dtype == torch.uint8,
                       "Input tensor must be uint8")
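Both helpers share one shape: validate dtype and device with invalidInputError, allocate the destination via torch.empty_like, then pass raw data pointers to the native converter. A self-contained sketch of that shape, with ctypes.memmove standing in for the ggml call so it runs without the compiled library (the helper names here are illustrative, not part of the commit):

import ctypes
import torch

def _native_convert_stub(src, dst, nbytes):
    # Stand-in: the real ggml call rewrites the quantized block layout.
    ctypes.memmove(dst, src, nbytes)

def q_format_convert(tensor: torch.Tensor) -> torch.Tensor:
    assert tensor.dtype == torch.uint8, "Input tensor must be uint8"
    assert tensor.device == torch.device('cpu'), "Input tensor must be on CPU"
    dst_tensor = torch.empty_like(tensor)
    _native_convert_stub(ctypes.c_void_p(tensor.data_ptr()),
                         ctypes.c_void_p(dst_tensor.data_ptr()),
                         tensor.numel())
    return dst_tensor

t = torch.randint(0, 256, (64,), dtype=torch.uint8)
assert torch.equal(q_format_convert(t), t)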
@@ -106,7 +140,6 @@ class FP4Params(torch.nn.Parameter):
     def __new__(cls,
                 data=None,
                 requires_grad=False,
-                old_data=None,
                 quantized=False,
                 _shape=None,
                 qtype=None):
@@ -154,7 +187,10 @@ class FP4Params(torch.nn.Parameter):
             return self.quantize(device.type)
         elif (device is not None and device.type == "xpu" and self.data.device.type == "cpu"):
             # enter xpu logic, compile linear_int4 extension at first time
-            q_tensor = self.quantize(device)  # tensor is cpu now
+            self.quantize(device)  # tensor is cpu now
+            self.data = ggml_q_format_convet_cpu2xpu(self.data,
+                                                     reduce(mul, self._shape, 1),
+                                                     self.qtype)
             new_param = FP4Params(super().to(device=device,
                                              dtype=dtype,
                                              non_blocking=non_blocking),
@@ -163,6 +199,18 @@ class FP4Params(torch.nn.Parameter):
                                   _shape=self._shape,
                                   qtype=self.qtype)
             return new_param
+        elif (device is not None and device.type == "cpu" and self.data.device.type == "xpu"):
+            new_param = FP4Params(super().to(device=device,
+                                             dtype=dtype,
+                                             non_blocking=non_blocking),
+                                  requires_grad=self.requires_grad,
+                                  quantized=self.quantized,
+                                  _shape=self._shape,
+                                  qtype=self.qtype)
+            new_param.data = ggml_q_format_convet_xpu2cpu(new_param.data,
+                                                          reduce(mul, new_param._shape, 1),
+                                                          new_param.qtype)
+            return new_param
         else:
             new_param = FP4Params(super().to(device=device,
                                              dtype=dtype,
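This xpu -> cpu branch mirrors the cpu -> xpu branch added above; as far as the diff shows, the two conversions are intended to be inverses, so weights re-laid-out for the XPU kernels come back in the portable ggml CPU format when moved off the device, and a saved state_dict stays loadable on CPU. A tiny sketch of that round-trip invariant, with identity clones standing in for the native converters:

import torch

cpu2xpu_layout = lambda t: t.clone()   # stand-in for ggml_q_format_convet_cpu2xpu
xpu2cpu_layout = lambda t: t.clone()   # stand-in for ggml_q_format_convet_xpu2cpu

w = torch.randint(0, 256, (32,), dtype=torch.uint8)
assert torch.equal(xpu2cpu_layout(cpu2xpu_layout(w)), w)  # round-trip is identity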
@@ -217,7 +265,6 @@ class LowBitLinear(nn.Linear):
         super().__init__(input_features, output_features, bias)
         self.weight = FP4Params(self.weight.data,
                                 requires_grad=False,
-                                old_data=self.weight.data,
                                 quantized=False, _shape=None, qtype=qtype)
         self.in_len = input_features
         self.out_len = output_features
@@ -249,7 +296,7 @@
             if x_2d.shape[0] > 1 and x_2d.dtype == torch.float32:
                 x_2d = x_2d.half()
             # input format of linear_q4.forward is 1: input, 2: weight
-            result = linear_q4_0.forward(x_2d, x0, self.qtype)
+            result = linear_q4_0.forward_new(x_2d, x0, self.qtype)
             new_shape = x_shape[:-1] + (self.out_len,)
             result = result.view(new_shape)
             if self.bias is not None:
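forward_new is the XPU extension's entry point for the new weight layout; only the kernel name changes, while the surrounding reshape convention is untouched: flatten leading dimensions into a 2-D matrix for the kernel, then restore them on the result. A sketch of that convention, with a plain float matmul standing in for linear_q4_0.forward_new:

import torch

def forward_like(x: torch.Tensor, weight: torch.Tensor, out_len: int) -> torch.Tensor:
    x_shape = x.shape
    x_2d = x.contiguous().view(-1, x_shape[-1])    # collapse batch/seq dims
    result = torch.matmul(x_2d, weight.t())        # stand-in for linear_q4_0.forward_new
    return result.view(x_shape[:-1] + (out_len,))  # restore leading dims

x = torch.randn(2, 5, 64)
w = torch.randn(128, 64)
assert forward_like(x, w, 128).shape == (2, 5, 128)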