LLM: add default torch_dtype for fp16. (#10124)

* set default torch_dtype for fp16.

* fix style.

* bug fix.

* update bug fix.
Cengguang Zhang 2024-02-08 10:24:16 +08:00 committed by GitHub
parent 1aa0c623ce
commit 0cf6a12691

@@ -155,6 +155,7 @@ class _BaseAutoModelClass:
         optimize_model = kwargs.pop("optimize_model", True)
         user_quantization_config = kwargs.pop("quantization_config", None)
         speculative = kwargs.pop("speculative", False)
+        torch_dtype = kwargs.pop("torch_dtype", None)
         if user_quantization_config is not None and \
                 "BitsAndBytesConfig" in str(user_quantization_config.__class__):
@@ -250,8 +251,19 @@
             # load int x-bit
             kwargs["low_cpu_mem_usage"] = True
-            # set default torch_dtype='auto'
-            kwargs["torch_dtype"] = kwargs.get("torch_dtype", 'auto')
+            # set default torch_dtype='auto'.
+            # Note that when load_in_low_bit="fp16", set default torch_dtype=torch.float16
+            if load_in_low_bit == "fp16":
+                if torch_dtype is not None and torch_dtype != torch.float16:
+                    invalidInputError(
+                        False,
+                        f"Please use torch_dtype=torch.float16"
+                        f" when setting load_in_low_bit='fp16'."
+                    )
+                else:
+                    kwargs["torch_dtype"] = torch.float16
+            else:
+                kwargs["torch_dtype"] = torch_dtype or "auto"
             # Avoid tensor parallel F.Linear Operations
             if "pretraining_tp" in config_dict:
                 if "config" in kwargs: