diff --git a/python/llm/src/bigdl/llm/transformers/model.py b/python/llm/src/bigdl/llm/transformers/model.py
index 651c63f1..022190e6 100644
--- a/python/llm/src/bigdl/llm/transformers/model.py
+++ b/python/llm/src/bigdl/llm/transformers/model.py
@@ -155,6 +155,7 @@ class _BaseAutoModelClass:
         optimize_model = kwargs.pop("optimize_model", True)
         user_quantization_config = kwargs.pop("quantization_config", None)
         speculative = kwargs.pop("speculative", False)
+        torch_dtype = kwargs.pop("torch_dtype", None)
 
         if user_quantization_config is not None and \
                 "BitsAndBytesConfig" in str(user_quantization_config.__class__):
@@ -250,8 +251,19 @@ class _BaseAutoModelClass:
 
             # load int x-bit
             kwargs["low_cpu_mem_usage"] = True
-            # set default torch_dtype='auto'
-            kwargs["torch_dtype"] = kwargs.get("torch_dtype", 'auto')
+            # set default torch_dtype='auto'.
+            # Note that when load_in_low_bit="fp16", set default torch_dtype=torch.float16
+            if load_in_low_bit == "fp16":
+                if torch_dtype is not None and torch_dtype != torch.float16:
+                    invalidInputError(
+                        False,
+                        f"Please use torch_dtype=torch.float16"
+                        f" when setting load_in_low_bit='fp16'."
+                    )
+                else:
+                    kwargs["torch_dtype"] = torch.float16
+            else:
+                kwargs["torch_dtype"] = torch_dtype or "auto"
             # Avoid tensor parallel F.Linear Operations
             if "pretraining_tp" in config_dict:
                 if "config" in kwargs:
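
For context, a minimal usage sketch of the behavior this patch introduces. The `load_in_low_bit` and `torch_dtype` arguments are the ones touched by the diff; the model id and the `"sym_int4"` value are illustrative assumptions, not part of the change itself.

```python
import torch
from bigdl.llm.transformers import AutoModelForCausalLM

# With this patch, load_in_low_bit="fp16" implies torch_dtype=torch.float16.
# Passing any other dtype together with "fp16" raises via invalidInputError;
# omitting torch_dtype is equivalent to passing torch.float16 explicitly.
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-chat-hf",  # hypothetical model id, for illustration only
    load_in_low_bit="fp16",
    torch_dtype=torch.float16,        # optional; the same default is applied if omitted
)

# For other low-bit formats the previous behavior is kept:
# torch_dtype falls back to "auto" when not supplied.
low_bit_model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-chat-hf",  # hypothetical model id
    load_in_low_bit="sym_int4",
)
```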