LLM: add default torch_dtype for fp16. (#10124)

* set default torch_dtype for fp16.
* fix style.
* bug fix.
* update bug fix.
parent 1aa0c623ce
commit 0cf6a12691
1 changed file with 14 additions and 2 deletions
@@ -155,6 +155,7 @@ class _BaseAutoModelClass:
         optimize_model = kwargs.pop("optimize_model", True)
         user_quantization_config = kwargs.pop("quantization_config", None)
         speculative = kwargs.pop("speculative", False)
+        torch_dtype = kwargs.pop("torch_dtype", None)

         if user_quantization_config is not None and \
                 "BitsAndBytesConfig" in str(user_quantization_config.__class__):
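For illustration only (not part of the diff): popping the kwarg both captures the caller's value for the fp16 check in the next hunk and removes the key from kwargs, so the dtype is only forwarded to the underlying transformers loader if the method puts it back explicitly. A minimal sketch with hypothetical kwargs:

    # Hypothetical kwargs, only to show the pop semantics used above.
    kwargs = {"torch_dtype": None, "trust_remote_code": True}
    torch_dtype = kwargs.pop("torch_dtype", None)   # captured for the later fp16 validation
    assert "torch_dtype" not in kwargs              # no longer forwarded implicitly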
@@ -250,8 +251,19 @@
             # load int x-bit
             kwargs["low_cpu_mem_usage"] = True
-            # set default torch_dtype='auto'
-            kwargs["torch_dtype"] = kwargs.get("torch_dtype", 'auto')
+            # set default torch_dtype='auto'.
+            # Note that when load_in_low_bit="fp16", set default torch_dtype=torch.float16
+            if load_in_low_bit == "fp16":
+                if torch_dtype is not None and torch_dtype != torch.float16:
+                    invalidInputError(
+                        False,
+                        f"Please use torch_dtype=torch.float16"
+                        f" when setting load_in_low_bit='fp16'."
+                    )
+                else:
+                    kwargs["torch_dtype"] = torch.float16
+            else:
+                kwargs["torch_dtype"] = torch_dtype or "auto"
             # Avoid tensor parallel F.Linear Operations
             if "pretraining_tp" in config_dict:
                 if "config" in kwargs:
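For context, a minimal usage sketch of the new default. The import path, model path, and the "sym_int4" mode are assumptions for illustration and are not taken from this diff: with load_in_low_bit="fp16" and no torch_dtype, the weights now default to torch.float16; a conflicting torch_dtype is rejected through the invalidInputError check; other low-bit modes keep the previous 'auto' default.

    import torch
    from bigdl.llm.transformers import AutoModelForCausalLM  # import path assumed

    model_path = "/path/to/model"  # placeholder

    # fp16 low-bit loading: torch_dtype now defaults to torch.float16.
    model = AutoModelForCausalLM.from_pretrained(model_path, load_in_low_bit="fp16")

    # A conflicting dtype is rejected by the invalidInputError branch above.
    try:
        AutoModelForCausalLM.from_pretrained(model_path,
                                             load_in_low_bit="fp16",
                                             torch_dtype=torch.bfloat16)
    except Exception as e:
        print(e)  # Please use torch_dtype=torch.float16 when setting load_in_low_bit='fp16'.

    # Any other low-bit mode keeps the old behaviour: torch_dtype falls back to 'auto'.
    model_int4 = AutoModelForCausalLM.from_pretrained(model_path, load_in_low_bit="sym_int4")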