[LLM] Set torch_dtype default value to 'auto' for transformers low bit from_pretrained API
This commit is contained in:
parent
bbde423349
commit
ba42a6da63
1 changed files with 2 additions and 0 deletions
|
|
@ -65,6 +65,8 @@ class _BaseAutoModelClass:
|
|||
if load_in_4bit or load_in_low_bit:
|
||||
# load int x-bit
|
||||
kwargs["low_cpu_mem_usage"] = True
|
||||
# set default torch_dtype='auto'
|
||||
kwargs["torch_dtype"] = kwargs.get("torch_dtype", 'auto')
|
||||
q_k = load_in_low_bit if load_in_low_bit else "sym_int4"
|
||||
model = cls.load_convert(q_k, *args, **kwargs)
|
||||
else:
|
||||
|
|
|
|||
Loading…
Reference in a new issue