LLM: fix torch_dtype setting when applying fp16 optimization through optimize_model (#10556)

binbin Deng 2024-03-27 14:18:45 +08:00 committed by GitHub
parent ea4bc450c4
commit fc8c7904f0


@@ -237,9 +237,19 @@ def optimize_model(model, low_bit='sym_int4', optimize_llm=True, modules_to_not_
         warnings.warn("replace_embedding is deprecated and will be removed in a future version,"
                       " please use cpu_embedding instead.", FutureWarning)
         cpu_embedding = True
+    if low_bit == "fp16":
+        torch_dtype = kwargs.get("torch_dtype", None)
+        if torch_dtype is not None and torch_dtype != torch.float16:
+            invalidInputError(False,
+                              "Please use torch_dtype=torch.float16 when setting low_bit='fp16'.")
+        else:
+            torch_dtype = torch.float16
+    else:
+        torch_dtype = kwargs.get("torch_dtype", "auto")
     qtype = ggml_tensor_qtype[low_bit]
     model = ggml_convert_low_bit(model,
                                  qtype=qtype,
+                                 torch_dtype=torch_dtype,
                                  optimize_model=optimize_llm,
                                  modules_to_not_convert=modules_to_not_convert,
                                  cpu_embedding=cpu_embedding,
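
The sketch below illustrates the intended behaviour after this change: with low_bit="fp16", optimize_model now resolves torch_dtype to torch.float16 itself and rejects a conflicting torch_dtype keyword, while other low_bit values keep the previous fallback of reading torch_dtype from kwargs with "auto" as the default. The import path and the model name are assumptions for illustration, not taken from this commit.

import torch
from transformers import AutoModelForCausalLM
from ipex_llm import optimize_model  # assumed import path; adjust to your installed package

# Load a Hugging Face model as usual (the model id is only illustrative).
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf",
                                             torch_dtype=torch.float16)

# With low_bit="fp16", torch_dtype is forced to torch.float16 internally;
# passing torch_dtype=torch.float16 explicitly is also accepted.
model = optimize_model(model, low_bit="fp16", torch_dtype=torch.float16)

# A conflicting dtype now fails with an explicit error instead of being
# silently accepted, e.g.:
#     optimize_model(model, low_bit="fp16", torch_dtype=torch.bfloat16)
#     -> "Please use torch_dtype=torch.float16 when setting low_bit='fp16'."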