fix nf4 to cpu (#12722)
This commit is contained in:
		
							parent
							
								
									9aa4be8ced
								
							
						
					
					
						commit
						085974e307
					
				
					 1 changed file with 7 additions and 4 deletions
				
			
		| 
						 | 
				
			
			@ -204,12 +204,15 @@ def ggml_q_format_convet_cpu2xpu(tensor: torch.Tensor, num_elem: int, qtype: int
 | 
			
		|||
 | 
			
		||||
 | 
			
		||||
def ggml_q_format_convet_xpu2cpu(tensor: torch.Tensor, num_elem: int, qtype: int):
 | 
			
		||||
 | 
			
		||||
    if qtype == NF4:
 | 
			
		||||
        invalidInputError(tensor.dtype == torch.bfloat16,
 | 
			
		||||
                          "NF4 Input tensor must be bfloat16")
 | 
			
		||||
    else:
 | 
			
		||||
        invalidInputError(tensor.dtype == torch.uint8,
 | 
			
		||||
                          "Input tensor must be uint8")
 | 
			
		||||
 | 
			
		||||
    invalidInputError(tensor.device == torch.device('cpu'),
 | 
			
		||||
                      "Input tensor must be uint8")
 | 
			
		||||
                      "Input tensor must be on cpu")
 | 
			
		||||
 | 
			
		||||
    src = ctypes.c_void_p(tensor.data.data_ptr())
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in a new issue