Fix Loader issue with dtype fp16 (#10907)

Guancheng Fu 2024-04-29 10:16:02 +08:00 committed by GitHub
parent c9fac8c26b
commit fbcd7bc737

@@ -59,6 +59,8 @@ def load_model(
         model_kwargs["trust_remote_code"] = True
     if low_bit == "bf16":
         model_kwargs.update({"load_in_low_bit": low_bit, "torch_dtype": torch.bfloat16})
+    elif low_bit == "fp16":
+        model_kwargs.update({"load_in_low_bit": low_bit, "torch_dtype": torch.float16})
     else:
         model_kwargs.update({"load_in_low_bit": low_bit, "torch_dtype": 'auto'})
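For reference, a minimal standalone sketch of the dtype-selection logic this patch adds. The helper name `select_model_kwargs` is hypothetical and not part of the repository; only the branch on `low_bit` is taken from the diff. Before this fix, `low_bit == "fp16"` fell through to the `else` branch and `torch_dtype` stayed `'auto'`; with the patch it maps to `torch.float16`.

```python
import torch

def select_model_kwargs(low_bit: str) -> dict:
    # Hypothetical helper, not the repo's API: mirrors the dtype selection
    # inside the patched load_model.
    model_kwargs = {}
    if low_bit == "bf16":
        model_kwargs.update({"load_in_low_bit": low_bit, "torch_dtype": torch.bfloat16})
    elif low_bit == "fp16":
        # New branch added by this commit: fp16 now selects torch.float16
        model_kwargs.update({"load_in_low_bit": low_bit, "torch_dtype": torch.float16})
    else:
        model_kwargs.update({"load_in_low_bit": low_bit, "torch_dtype": "auto"})
    return model_kwargs

assert select_model_kwargs("fp16")["torch_dtype"] is torch.float16
```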