diff --git a/python/llm/src/ipex_llm/optimize.py b/python/llm/src/ipex_llm/optimize.py index ee1afc4d..dc199c00 100644 --- a/python/llm/src/ipex_llm/optimize.py +++ b/python/llm/src/ipex_llm/optimize.py @@ -237,9 +237,19 @@ def optimize_model(model, low_bit='sym_int4', optimize_llm=True, modules_to_not_ warnings.warn("replace_embedding is deprecated and will be removed in a future version," " please use cpu_embedding instead.", FutureWarning) cpu_embedding = True + if low_bit == "fp16": + torch_dtype = kwargs.get("torch_dtype", None) + if torch_dtype is not None and torch_dtype != torch.float16: + invalidInputError(False, + "Please use torch_dtype=torch.float16 when setting low_bit='fp16'.") + else: + torch_dtype = torch.float16 + else: + torch_dtype = kwargs.get("torch_dtype", "auto") qtype = ggml_tensor_qtype[low_bit] model = ggml_convert_low_bit(model, qtype=qtype, + torch_dtype=torch_dtype, optimize_model=optimize_llm, modules_to_not_convert=modules_to_not_convert, cpu_embedding=cpu_embedding,