LLM: Disable transformer api pretraining_tp (#8645)
* disable pretraining_tp
This commit is contained in:
parent
6fc31bb4cf
commit
04c713ef06
1 changed file with 3 additions and 0 deletions
|
|
@ -80,6 +80,9 @@ class _BaseAutoModelClass:
|
|||
kwargs["low_cpu_mem_usage"] = True
|
||||
# set default torch_dtype='auto'
|
||||
kwargs["torch_dtype"] = kwargs.get("torch_dtype", 'auto')
|
||||
# Avoid tensor parallel F.Linear Operations
|
||||
if "pretraining_tp" in config_dict:
|
||||
kwargs["pretraining_tp"] = 1
|
||||
q_k = load_in_low_bit if load_in_low_bit else "sym_int4"
|
||||
model = cls.load_convert(q_k, *args, **kwargs)
|
||||
else:
|
||||
|
|
|
|||
Loading…
Reference in a new issue