fix nf4 to cpu (#12722)
This commit is contained in:
parent
9aa4be8ced
commit
085974e307
1 changed file with 7 additions and 4 deletions
|
|
@@ -204,12 +204,15 @@ def ggml_q_format_convet_cpu2xpu(tensor: torch.Tensor, num_elem: int, qtype: int
|
||||||
|
|
||||||
|
|
||||||
def ggml_q_format_convet_xpu2cpu(tensor: torch.Tensor, num_elem: int, qtype: int):
|
def ggml_q_format_convet_xpu2cpu(tensor: torch.Tensor, num_elem: int, qtype: int):
|
||||||
|
if qtype == NF4:
|
||||||
invalidInputError(tensor.dtype == torch.uint8,
|
invalidInputError(tensor.dtype == torch.bfloat16,
|
||||||
"Input tensor must be uint8")
|
"NF4 Input tensor must be bfloat16")
|
||||||
|
else:
|
||||||
|
invalidInputError(tensor.dtype == torch.uint8,
|
||||||
|
"Input tensor must be uint8")
|
||||||
|
|
||||||
invalidInputError(tensor.device == torch.device('cpu'),
|
invalidInputError(tensor.device == torch.device('cpu'),
|
||||||
"Input tensor must be uint8")
|
"Input tensor must be on cpu")
|
||||||
|
|
||||||
src = ctypes.c_void_p(tensor.data.data_ptr())
|
src = ctypes.c_void_p(tensor.data.data_ptr())
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue