Fix dtype mismatch error (#10609)
* fix llama
* fix
* fix code style
* add torch type in model.py

---------

Co-authored-by: arda <arda@arda-arc19.sh.intel.com>
parent f37a1f2a81
commit b4147a97bb

1 changed file with 9 additions and 0 deletions
model.py

@@ -295,6 +295,15 @@ class _BaseAutoModelClass:
                 )
             else:
                 kwargs["torch_dtype"] = torch.float16
+        elif load_in_low_bit == "bf16":
+            if torch_dtype is not None and torch_dtype != torch.bfloat16:
+                invalidInputError(
+                    False,
+                    f"Please use torch_dtype=torch.bfloat16"
+                    f" when setting load_in_low_bit='bf16'."
+                )
+            else:
+                kwargs["torch_dtype"] = torch.bfloat16
         else:
             kwargs["torch_dtype"] = torch_dtype or "auto"
         # Avoid tensor parallel F.Linear Operations
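For reference, a minimal caller-side sketch of what this hunk changes, assuming the ipex-llm AutoModelForCausalLM class that derives from _BaseAutoModelClass; the model id is illustrative:

    import torch
    from ipex_llm.transformers import AutoModelForCausalLM

    MODEL_ID = "meta-llama/Llama-2-7b-hf"  # illustrative model id

    # Matching dtype and low-bit setting: loads normally in bfloat16.
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID, load_in_low_bit="bf16", torch_dtype=torch.bfloat16)

    # Mismatched dtype: after this commit, invalidInputError fires with
    # "Please use torch_dtype=torch.bfloat16 when setting load_in_low_bit='bf16'."
    # model = AutoModelForCausalLM.from_pretrained(
    #     MODEL_ID, load_in_low_bit="bf16", torch_dtype=torch.float16)

    # Omitted dtype: the new else branch fills kwargs["torch_dtype"] with
    # torch.bfloat16, so the weights are loaded as bfloat16 instead of a
    # mismatched default.
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID, load_in_low_bit="bf16")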