fix nf4 to cpu (#12722)

This commit is contained in:
Yishuo Wang 2025-01-21 09:23:22 +08:00 committed by GitHub
parent 9aa4be8ced
commit 085974e307
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -204,12 +204,15 @@ def ggml_q_format_convet_cpu2xpu(tensor: torch.Tensor, num_elem: int, qtype: int
def ggml_q_format_convet_xpu2cpu(tensor: torch.Tensor, num_elem: int, qtype: int): def ggml_q_format_convet_xpu2cpu(tensor: torch.Tensor, num_elem: int, qtype: int):
if qtype == NF4:
invalidInputError(tensor.dtype == torch.uint8, invalidInputError(tensor.dtype == torch.bfloat16,
"Input tensor must be uint8") "NF4 Input tensor must be bfloat16")
else:
invalidInputError(tensor.dtype == torch.uint8,
"Input tensor must be uint8")
invalidInputError(tensor.device == torch.device('cpu'), invalidInputError(tensor.device == torch.device('cpu'),
"Input tensor must be uint8") "Input tensor must be on cpu")
src = ctypes.c_void_p(tensor.data.data_ptr()) src = ctypes.c_void_p(tensor.data.data_ptr())