From fc8c7904f0575da34d3f744678b1227e67a047a2 Mon Sep 17 00:00:00 2001
From: binbin Deng <108676127+plusbang@users.noreply.github.com>
Date: Wed, 27 Mar 2024 14:18:45 +0800
Subject: [PATCH] LLM: fix torch_dtype setting of apply fp16 optimization through optimize_model (#10556)

---
 python/llm/src/ipex_llm/optimize.py | 10 ++++++++++
 1 file changed, 10 insertions(+)

diff --git a/python/llm/src/ipex_llm/optimize.py b/python/llm/src/ipex_llm/optimize.py
index ee1afc4d..dc199c00 100644
--- a/python/llm/src/ipex_llm/optimize.py
+++ b/python/llm/src/ipex_llm/optimize.py
@@ -237,9 +237,19 @@ def optimize_model(model, low_bit='sym_int4', optimize_llm=True, modules_to_not_
         warnings.warn("replace_embedding is deprecated and will be removed in a future version,"
                       " please use cpu_embedding instead.", FutureWarning)
         cpu_embedding = True
+    if low_bit == "fp16":
+        torch_dtype = kwargs.get("torch_dtype", None)
+        if torch_dtype is not None and torch_dtype != torch.float16:
+            invalidInputError(False,
+                              "Please use torch_dtype=torch.float16 when setting low_bit='fp16'.")
+        else:
+            torch_dtype = torch.float16
+    else:
+        torch_dtype = kwargs.get("torch_dtype", "auto")
     qtype = ggml_tensor_qtype[low_bit]
     model = ggml_convert_low_bit(model,
                                  qtype=qtype,
+                                 torch_dtype=torch_dtype,
                                  optimize_model=optimize_llm,
                                  modules_to_not_convert=modules_to_not_convert,
                                  cpu_embedding=cpu_embedding,
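
Usage note (not part of the patch): the sketch below illustrates the call path this change affects, assuming the public "from ipex_llm import optimize_model" entry point; the model id is a placeholder, not taken from the patch.

import torch
from transformers import AutoModelForCausalLM
from ipex_llm import optimize_model

# Load a Hugging Face causal LM in half precision (model id is a placeholder).
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf",
                                             torch_dtype=torch.float16)

# With this patch, low_bit="fp16" forwards torch_dtype=torch.float16 to
# ggml_convert_low_bit. Passing any other torch_dtype raises an error via
# invalidInputError, and omitting torch_dtype defaults it to torch.float16.
model = optimize_model(model, low_bit="fp16", torch_dtype=torch.float16)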