diff --git a/python/llm/src/ipex_llm/transformers/low_bit_linear.py b/python/llm/src/ipex_llm/transformers/low_bit_linear.py
index 7bc0f7ce..070395be 100644
--- a/python/llm/src/ipex_llm/transformers/low_bit_linear.py
+++ b/python/llm/src/ipex_llm/transformers/low_bit_linear.py
@@ -81,7 +81,12 @@
 Q4_K = ggml_tensor_qtype["q4_k"]
 Q6_K = ggml_tensor_qtype["q6_k"]
 Q5_K = ggml_tensor_qtype["q5_k"]
 FP6_K = ggml_tensor_qtype["fp6_k"]
+SYM_INT4_RTN = ggml_tensor_qtype["sym_int4_rtn"]
 SYM_INT8_RTN = ggml_tensor_qtype["sym_int8_rtn"]
+RTN_DTYPE = {
+    SYM_INT4_RTN: torch.uint8,
+    SYM_INT8_RTN: torch.int8,
+}
 
 # For sym_int4
@@ -217,8 +222,8 @@ def ggml_convert_qtype(tensor: torch.Tensor, qtype: int,
                       f"Last dim of input tensor must be multiple of {QK}")
 
     dst_size = (n // QK) * block_size_in_bytes
-    if qtype in [SYM_INT8_RTN]:
-        dst_tensor = torch.empty(dst_size, dtype=torch.int8,
+    if qtype in [SYM_INT8_RTN, SYM_INT4_RTN]:
+        dst_tensor = torch.empty(dst_size, dtype=RTN_DTYPE[qtype],
                                  device=device)
         scale = torch.empty(n // k, dtype=torch.float32,
                             device=device)
@@ -230,11 +235,11 @@ def ggml_convert_qtype(tensor: torch.Tensor, qtype: int,
     dst = ctypes.c_void_p(dst_tensor.data.data_ptr())
     hist = (ctypes.c_int64 * 16)()
 
     if qtype not in [IQ2_XXS, IQ2_XS, Q2_K, IQ1_S, Q4_K, Q6_K, Q5_K, FP6_K]:
-        if qtype in [SYM_INT8_RTN]:
+        if qtype in [SYM_INT8_RTN, SYM_INT4_RTN]:
             scale_ptr = ctypes.cast(scale.data.data_ptr(), ctypes.POINTER(ctypes.c_float))
             ggml.ggml_quantize_tensor_rtn(src, dst, scale_ptr, qtype, n, k, hist, enable_scale_search)
-            dst_tensor = dst_tensor.reshape_as(tensor)
+            dst_tensor = dst_tensor.reshape(tensor.shape[0], tensor.shape[-1] // QK)
             return dst_tensor, scale.type(torch.float16)
         else:
             ggml.ggml_quantize_tensor(src, dst, qtype, n, k, hist, enable_scale_search)
diff --git a/python/llm/src/ipex_llm/transformers/npu_model.py b/python/llm/src/ipex_llm/transformers/npu_model.py
index eb11bcef..cec84411 100644
--- a/python/llm/src/ipex_llm/transformers/npu_model.py
+++ b/python/llm/src/ipex_llm/transformers/npu_model.py
@@ -76,7 +76,7 @@ class _BaseAutoModelClass:
             # for intel_npu_acceleration_library >= 1.1.0
             from intel_npu_acceleration_library.dtypes import int8, int4
             qtype_map = {
-                'sym_int4': int4,
+                'sym_int4': "sym_int4_rtn",
                 'sym_int8': "sym_int8_rtn",
                 'fp16': torch.half,
                 'fp32': torch.float,
@@ -119,7 +119,7 @@ class _BaseAutoModelClass:
         from intel_npu_acceleration_library.compiler import create_npu_kernels
         with torch.no_grad():
             optimize_llm(model)
-            if qtype == "sym_int8_rtn":
+            if qtype in ["sym_int8_rtn", "sym_int4_rtn"]:
                 cls.load_convert(qtype, model, *args, **kwargs)
             else:
                 if not qtype.is_floating_point:
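
For context on the new `sym_int4_rtn` path, below is a minimal plain-PyTorch sketch of symmetric round-to-nearest (RTN) int4 quantization with per-group scales, packing two 4-bit values per `uint8` byte. The group size, helper name, and nibble packing order are illustrative assumptions only; the actual layout and scales are produced by the native `ggml_quantize_tensor_rtn` kernel, which this sketch does not reproduce.

```python
import torch

GROUP_SIZE = 64  # assumed quantization group size, for illustration only


def rtn_int4_quantize(weight: torch.Tensor):
    """Hypothetical helper: symmetric RTN int4 quantization with per-group scales."""
    rows, cols = weight.shape
    groups = weight.reshape(rows, cols // GROUP_SIZE, GROUP_SIZE).float()
    # Symmetric scale: map the largest absolute value in each group onto 7.
    scale = groups.abs().amax(dim=-1, keepdim=True) / 7.0
    scale = torch.where(scale == 0, torch.ones_like(scale), scale)
    q = torch.clamp(torch.round(groups / scale), -8, 7).to(torch.int8)
    # Shift to unsigned nibbles (0..15) and pack two values per uint8 byte.
    u = (q + 8).to(torch.uint8).reshape(rows, cols)
    packed = u[:, 0::2] | (u[:, 1::2] << 4)
    return packed, scale.squeeze(-1).half()


w = torch.randn(128, 256)
packed, scales = rtn_int4_quantize(w)
print(packed.shape, packed.dtype)   # torch.Size([128, 128]) torch.uint8
print(scales.shape, scales.dtype)   # torch.Size([128, 4]) torch.float16
```

With this patch, requesting `load_in_low_bit="sym_int4"` through the NPU `from_pretrained` path maps to the `sym_int4_rtn` qtype and takes the same `load_convert` branch that `sym_int8_rtn` already uses, with the quantized weights stored as `torch.uint8` and the per-group scales returned as `float16`.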