[LLM] Set torch_dtype default value to 'auto' for transformers low bit from_pretrained API
This commit is contained in:
parent
bbde423349
commit
ba42a6da63
1 changed files with 2 additions and 0 deletions
|
|
@@ -65,6 +65,8 @@ class _BaseAutoModelClass:
|
||||||
if load_in_4bit or load_in_low_bit:
|
if load_in_4bit or load_in_low_bit:
|
||||||
# load int x-bit
|
# load int x-bit
|
||||||
kwargs["low_cpu_mem_usage"] = True
|
kwargs["low_cpu_mem_usage"] = True
|
||||||
|
# set default torch_dtype='auto'
|
||||||
|
kwargs["torch_dtype"] = kwargs.get("torch_dtype", 'auto')
|
||||||
q_k = load_in_low_bit if load_in_low_bit else "sym_int4"
|
q_k = load_in_low_bit if load_in_low_bit else "sym_int4"
|
||||||
model = cls.load_convert(q_k, *args, **kwargs)
|
model = cls.load_convert(q_k, *args, **kwargs)
|
||||||
else:
|
else:
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue