Fix dtype mismatch error (#10609)

* fix llama

* fix

* fix code style

* add torch type in model.py

---------

Co-authored-by: arda <arda@arda-arc19.sh.intel.com>
Zhicun committed 2024-04-09 17:50:33 +08:00 (via GitHub)
parent f37a1f2a81
commit b4147a97bb


@@ -295,6 +295,15 @@ class _BaseAutoModelClass:
                 )
             else:
                 kwargs["torch_dtype"] = torch.float16
+        elif load_in_low_bit == "bf16":
+            if torch_dtype is not None and torch_dtype != torch.bfloat16:
+                invalidInputError(
+                    False,
+                    f"Please use torch_dtype=torch.bfloat16"
+                    f" when setting load_in_low_bit='bf16'."
+                )
+            else:
+                kwargs["torch_dtype"] = torch.bfloat16
         else:
             kwargs["torch_dtype"] = torch_dtype or "auto"
         # Avoid tensor parallel F.Linear Operations
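For reference, a minimal caller-side sketch of the behavior this hunk enforces, mirroring the existing fp16 branch shown in the context lines. It assumes the ipex-llm AutoModelForCausalLM.from_pretrained entry point; the checkpoint path is illustrative, not taken from this commit:

    import torch
    from ipex_llm.transformers import AutoModelForCausalLM

    # Accepted: torch_dtype matches the requested low-bit format (or is
    # omitted, in which case it now defaults to torch.bfloat16).
    model = AutoModelForCausalLM.from_pretrained(
        "meta-llama/Llama-2-7b-hf",      # hypothetical checkpoint path
        load_in_low_bit="bf16",
        torch_dtype=torch.bfloat16,
    )

    # Rejected: a mismatched dtype now fails fast via invalidInputError
    # instead of surfacing later as a runtime dtype mismatch.
    model = AutoModelForCausalLM.from_pretrained(
        "meta-llama/Llama-2-7b-hf",
        load_in_low_bit="bf16",
        torch_dtype=torch.float16,       # must be torch.bfloat16 with 'bf16'
    )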