diff --git a/python/llm/src/ipex_llm/transformers/model.py b/python/llm/src/ipex_llm/transformers/model.py
index be3f13a5..44b2e0ad 100644
--- a/python/llm/src/ipex_llm/transformers/model.py
+++ b/python/llm/src/ipex_llm/transformers/model.py
@@ -295,6 +295,15 @@ class _BaseAutoModelClass:
                     )
                 else:
                     kwargs["torch_dtype"] = torch.float16
+            elif load_in_low_bit == "bf16":
+                if torch_dtype is not None and torch_dtype != torch.bfloat16:
+                    invalidInputError(
+                        False,
+                        f"Please use torch_dtype=torch.bfloat16"
+                        f" when setting load_in_low_bit='bf16'."
+                    )
+                else:
+                    kwargs["torch_dtype"] = torch.bfloat16
             else:
                 kwargs["torch_dtype"] = torch_dtype or "auto"
             # Avoid tensor parallel F.Linear Operations