[LLM] Set torch_dtype default value to 'auto' for the transformers low-bit from_pretrained API

Yuwen Hu 2023-07-21 17:55:00 +08:00 committed by GitHub
parent bbde423349
commit ba42a6da63


@@ -65,6 +65,8 @@ class _BaseAutoModelClass:
         if load_in_4bit or load_in_low_bit:
             # load int x-bit
             kwargs["low_cpu_mem_usage"] = True
+            # set default torch_dtype='auto'
+            kwargs["torch_dtype"] = kwargs.get("torch_dtype", 'auto')
             q_k = load_in_low_bit if load_in_low_bit else "sym_int4"
             model = cls.load_convert(q_k, *args, **kwargs)
         else:
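
For context, a minimal usage sketch of the changed behavior (not part of the commit; the import path and model identifier are assumptions based on the BigDL-LLM transformers wrapper API):

from bigdl.llm.transformers import AutoModelForCausalLM

# torch_dtype is omitted here, so the low-bit from_pretrained path now
# fills it in as 'auto' before conversion; an explicitly passed
# torch_dtype would still take precedence over this default.
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-chat-hf",  # illustrative model id (assumption)
    load_in_4bit=True,
)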