LLM: fix torch_dtype setting of apply fp16 optimization through optimize_model (#10556)
This commit is contained in:
parent ea4bc450c4
commit fc8c7904f0
1 changed file with 10 additions and 0 deletions
@@ -237,9 +237,19 @@ def optimize_model(model, low_bit='sym_int4', optimize_llm=True, modules_to_not_
         warnings.warn("replace_embedding is deprecated and will be removed in a future version,"
                       " please use cpu_embedding instead.", FutureWarning)
         cpu_embedding = True
+    if low_bit == "fp16":
+        torch_dtype = kwargs.get("torch_dtype", None)
+        if torch_dtype is not None and torch_dtype != torch.float16:
+            invalidInputError(False,
+                              "Please use torch_dtype=torch.float16 when setting low_bit='fp16'.")
+        else:
+            torch_dtype = torch.float16
+    else:
+        torch_dtype = kwargs.get("torch_dtype", "auto")
     qtype = ggml_tensor_qtype[low_bit]
     model = ggml_convert_low_bit(model,
                                  qtype=qtype,
+                                 torch_dtype=torch_dtype,
                                  optimize_model=optimize_llm,
                                  modules_to_not_convert=modules_to_not_convert,
                                  cpu_embedding=cpu_embedding,
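In effect, when optimize_model is called with low_bit="fp16", torch_dtype may only be omitted or set to torch.float16; any other value now raises invalidInputError, while other low_bit values keep the previous "auto" default. Below is a minimal usage sketch of the fixed behavior; the bigdl.llm import path and the Hugging Face model name are assumptions for illustration, not taken from this diff.

import torch
from transformers import AutoModelForCausalLM
from bigdl.llm import optimize_model   # import path assumed for this codebase

# Load the model in float16 so the weights already match low_bit="fp16".
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf",  # example model, an assumption
                                             torch_dtype=torch.float16)

# OK after this fix: torch_dtype omitted, the patched branch defaults it to torch.float16.
model = optimize_model(model, low_bit="fp16")

# Also OK: torch_dtype passed explicitly as torch.float16.
# model = optimize_model(model, low_bit="fp16", torch_dtype=torch.float16)

# Rejected after this fix: any other torch_dtype combined with low_bit="fp16"
# hits invalidInputError with the message shown in the diff.
# model = optimize_model(model, low_bit="fp16", torch_dtype=torch.bfloat16)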