LLM: fix pretraining_tp argument issue. (#9281)

Cengguang Zhang 2023-10-26 18:43:58 +08:00 committed by GitHub
parent 6b2a32eba2
commit 44b5fcc190


@@ -93,6 +93,9 @@ class _BaseAutoModelClass:
            kwargs["torch_dtype"] = kwargs.get("torch_dtype", 'auto')
            # Avoid tensor parallel F.Linear Operations
            if "pretraining_tp" in config_dict:
                if "config" in kwargs:
                    setattr(kwargs["config"], "pretraining_tp", 1)
                else:
                    kwargs["pretraining_tp"] = 1
            q_k = load_in_low_bit if load_in_low_bit else "sym_int4"
            model = cls.load_convert(q_k, optimize_model, *args, **kwargs)
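For context, the sketch below mirrors the patched logic as a standalone helper. The helper name _disable_pretraining_tp and the LlamaConfig usage are illustrative only; the real code inlines this check in _BaseAutoModelClass.from_pretrained as shown in the diff. Llama-style configs carry a pretraining_tp field, and any value greater than 1 makes the model run its linear layers as tensor-parallel F.linear slices, which the low-bit load path avoids. When the caller already supplies a config object, the override is applied to that object rather than added as a loose keyword (a bare pretraining_tp=1 keyword next to an explicit config can be rejected as an unexpected argument, which appears to be the issue the commit title refers to).

from transformers import LlamaConfig

# Illustrative sketch only: the real code inlines this logic in
# _BaseAutoModelClass.from_pretrained as shown in the diff above.
def _disable_pretraining_tp(config_dict: dict, kwargs: dict) -> dict:
    # pretraining_tp > 1 makes Llama-style models split their linear layers
    # into tensor-parallel F.linear slices; force it back to 1 here.
    if "pretraining_tp" in config_dict:
        if "config" in kwargs:
            # Caller supplied an explicit config object: override the
            # attribute on that object instead of adding a loose keyword.
            setattr(kwargs["config"], "pretraining_tp", 1)
        else:
            # No config object: from_pretrained will fold this keyword
            # into the config it builds internally.
            kwargs["pretraining_tp"] = 1
    return kwargs


# Example: a config that requests tensor parallelism gets reset to 1.
config = LlamaConfig(pretraining_tp=2)
kwargs = _disable_pretraining_tp(config.to_dict(), {"config": config})
assert kwargs["config"].pretraining_tp == 1

kwargs = _disable_pretraining_tp(config.to_dict(), {})
assert kwargs["pretraining_tp"] == 1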