LLM: Disable transformer api pretraining_tp (#8645)
* disable pretraining_tp
This commit is contained in:
parent 6fc31bb4cf
commit 04c713ef06

1 changed file with 3 additions and 0 deletions
@@ -80,6 +80,9 @@ class _BaseAutoModelClass:
             kwargs["low_cpu_mem_usage"] = True
             # set default torch_dtype='auto'
             kwargs["torch_dtype"] = kwargs.get("torch_dtype", 'auto')
+            # Avoid tensor parallel F.Linear Operations
+            if "pretraining_tp" in config_dict:
+                kwargs["pretraining_tp"] = 1
             q_k = load_in_low_bit if load_in_low_bit else "sym_int4"
             model = cls.load_convert(q_k, *args, **kwargs)
         else:
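The "Avoid tensor parallel F.Linear Operations" comment refers to the code path in HuggingFace transformers that activates when config.pretraining_tp > 1: LLaMA layers then compute their projections with torch.nn.functional.linear on raw weight slices instead of calling the nn.Linear module, so a low-bit module that load_convert substituted for nn.Linear is silently bypassed. The sketch below is a simplified paraphrase of that branching (modeled on the down-projection logic in transformers' LLaMA layers, not verbatim library code) to show why forcing pretraining_tp = 1 keeps execution on the convertible forward() path.

# Simplified sketch of the two code paths selected by pretraining_tp,
# modeled on transformers' LLaMA layers (not verbatim library code).
import torch
import torch.nn.functional as F

def down_projection(down_proj: torch.nn.Linear, x: torch.Tensor,
                    pretraining_tp: int) -> torch.Tensor:
    if pretraining_tp > 1:
        # Tensor-parallel path: slices the raw weight and calls F.linear
        # directly, so a quantized module substituted for down_proj
        # never has its forward() invoked.
        x_slices = x.split(x.shape[-1] // pretraining_tp, dim=-1)
        w_slices = down_proj.weight.split(
            down_proj.weight.shape[1] // pretraining_tp, dim=1)
        return sum(F.linear(xs, ws) for xs, ws in zip(x_slices, w_slices))
    # pretraining_tp == 1: a plain module call, which dispatches to the
    # (possibly low-bit) forward() method, the path this commit forces.
    return down_proj(x)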
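For completeness, here is how the same override looks from the caller's side; a minimal sketch, assuming a checkpoint whose config.json carries a pretraining_tp entry (the model id below is only an illustrative placeholder). transformers applies unconsumed from_pretrained kwargs as config overrides, which is what lets the kwargs["pretraining_tp"] = 1 guard in the hunk above take effect.

# Minimal usage sketch; the checkpoint id is a placeholder and the
# calls need the usual HuggingFace Hub access to run.
from transformers import AutoModelForCausalLM, PretrainedConfig

model_id = "meta-llama/Llama-2-70b-hf"

# Read the raw config the same way the patched method can:
# get_config_dict() returns the config.json contents as a plain dict.
config_dict, _ = PretrainedConfig.get_config_dict(model_id)

kwargs = {}
if "pretraining_tp" in config_dict:
    # Mirror the commit's guard: only override when the entry exists.
    kwargs["pretraining_tp"] = 1

# Unrecognized from_pretrained kwargs become config overrides, so the
# loaded model runs its layers with pretraining_tp == 1.
model = AutoModelForCausalLM.from_pretrained(model_id, **kwargs)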