Fix Loader issue with dtype fp16 (#10907)
parent c9fac8c26b
commit fbcd7bc737
1 changed file with 2 additions and 0 deletions
@@ -59,6 +59,8 @@ def load_model(
         model_kwargs["trust_remote_code"] = True
     if low_bit == "bf16":
         model_kwargs.update({"load_in_low_bit": low_bit, "torch_dtype": torch.bfloat16})
+    elif low_bit == "fp16":
+        model_kwargs.update({"load_in_low_bit": low_bit, "torch_dtype": torch.float16})
     else:
         model_kwargs.update({"load_in_low_bit": low_bit, "torch_dtype": 'auto'})
 
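For context, the sketch below shows how the dtype-selection logic reads after this change. The standalone helper name build_model_kwargs and its arguments are assumptions made for illustration; only the if/elif/else branching mirrors the patched load_model.

import torch

def build_model_kwargs(low_bit, trust_remote_code=True):
    # Hypothetical helper reproducing the branching from load_model.
    model_kwargs = {}
    if trust_remote_code:
        model_kwargs["trust_remote_code"] = True
    if low_bit == "bf16":
        model_kwargs.update({"load_in_low_bit": low_bit, "torch_dtype": torch.bfloat16})
    elif low_bit == "fp16":
        # Branch added by this commit: map fp16 to torch.float16 instead of
        # falling through to the generic 'auto' default below.
        model_kwargs.update({"load_in_low_bit": low_bit, "torch_dtype": torch.float16})
    else:
        model_kwargs.update({"load_in_low_bit": low_bit, "torch_dtype": "auto"})
    return model_kwargs

Before this fix, low_bit == "fp16" fell through to the else branch, so torch_dtype was left as 'auto' instead of torch.float16.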