LLM: Disable transformer api pretraining_tp (#8645)

* disable pretraining_tp
This commit is contained in:
Zhao Changmin 2023-08-02 11:26:01 +08:00 committed by GitHub
parent 6fc31bb4cf
commit 04c713ef06

View file

@ -80,6 +80,9 @@ class _BaseAutoModelClass:
kwargs["low_cpu_mem_usage"] = True
# set default torch_dtype='auto'
kwargs["torch_dtype"] = kwargs.get("torch_dtype", 'auto')
# Avoid tensor parallel F.Linear Operations
if "pretraining_tp" in config_dict:
kwargs["pretraining_tp"] = 1
q_k = load_in_low_bit if load_in_low_bit else "sym_int4"
model = cls.load_convert(q_k, *args, **kwargs)
else: