From b4147a97bb7fd6d54837cfb4da7bea0a15f18ec9 Mon Sep 17 00:00:00 2001
From: Zhicun <59141989+ivy-lv11@users.noreply.github.com>
Date: Tue, 9 Apr 2024 17:50:33 +0800
Subject: [PATCH] Fix dtype mismatch error (#10609)

* fix llama

* fix

* fix code style

* add torch type in model.py

---------

Co-authored-by: arda
---
 python/llm/src/ipex_llm/transformers/model.py | 9 +++++++++
 1 file changed, 9 insertions(+)

diff --git a/python/llm/src/ipex_llm/transformers/model.py b/python/llm/src/ipex_llm/transformers/model.py
index be3f13a5..44b2e0ad 100644
--- a/python/llm/src/ipex_llm/transformers/model.py
+++ b/python/llm/src/ipex_llm/transformers/model.py
@@ -295,6 +295,15 @@ class _BaseAutoModelClass:
                 )
             else:
                 kwargs["torch_dtype"] = torch.float16
+        elif load_in_low_bit == "bf16":
+            if torch_dtype is not None and torch_dtype != torch.bfloat16:
+                invalidInputError(
+                    False,
+                    f"Please use torch_dtype=torch.bfloat16"
+                    f" when setting load_in_low_bit='bf16'."
+                )
+            else:
+                kwargs["torch_dtype"] = torch.bfloat16
         else:
             kwargs["torch_dtype"] = torch_dtype or "auto"
         # Avoid tensor parallel F.Linear Operations
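
A minimal usage sketch of the behavior this patch enforces (not part of the diff; assumes ipex_llm is installed and uses a placeholder model id):

    import torch
    from ipex_llm.transformers import AutoModelForCausalLM

    # With this patch, omitting torch_dtype while requesting bf16 loading
    # makes torch_dtype default to torch.bfloat16.
    model = AutoModelForCausalLM.from_pretrained(
        "meta-llama/Llama-2-7b-hf",   # placeholder model id
        load_in_low_bit="bf16",
    )

    # A mismatched dtype is now rejected up front via invalidInputError,
    # instead of surfacing as a dtype mismatch later in the forward pass.
    try:
        AutoModelForCausalLM.from_pretrained(
            "meta-llama/Llama-2-7b-hf",   # placeholder model id
            load_in_low_bit="bf16",
            torch_dtype=torch.float32,    # conflicts with load_in_low_bit='bf16'
        )
    except Exception as e:
        print(e)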