LLM: fix pretraining_tp argument issue. (#9281)
parent 6b2a32eba2
commit 44b5fcc190
1 changed file with 4 additions and 1 deletion
@@ -93,6 +93,9 @@ class _BaseAutoModelClass:
             kwargs["torch_dtype"] = kwargs.get("torch_dtype", 'auto')
             # Avoid tensor parallel F.Linear Operations
             if "pretraining_tp" in config_dict:
-                kwargs["pretraining_tp"] = 1
+                if "config" in kwargs:
+                    setattr(kwargs["config"], "pretraining_tp", 1)
+                else:
+                    kwargs["pretraining_tp"] = 1
             q_k = load_in_low_bit if load_in_low_bit else "sym_int4"
             model = cls.load_convert(q_k, optimize_model, *args, **kwargs)
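For context, a minimal, self-contained sketch of what the override does, assuming a Llama-style checkpoint whose config carries pretraining_tp > 1 (LlamaConfig and the value 2 below are illustrative, not taken from this repository). When pretraining_tp is greater than 1, transformers' Llama attention and MLP layers slice their Linear weights and issue one F.linear call per slice, which the low-bit conversion cannot replace; forcing the value back to 1 keeps the single-Linear code path.

# Illustrative sketch: LlamaConfig(pretraining_tp=2) stands in for a checkpoint
# whose config.json ships a tensor-parallel setting greater than 1.
from transformers import LlamaConfig

config = LlamaConfig(pretraining_tp=2)
config_dict = config.to_dict()
kwargs = {"config": config}  # the caller passed an explicit config object

# Same decision as the patch: patch the config object the caller passed in;
# otherwise forward the override through the from_pretrained kwargs.
if "pretraining_tp" in config_dict:
    if "config" in kwargs:
        setattr(kwargs["config"], "pretraining_tp", 1)
    else:
        kwargs["pretraining_tp"] = 1

assert kwargs["config"].pretraining_tp == 1

Handling the explicit-config case separately matters because, when a config instance is supplied, transformers forwards extra loading kwargs to the model constructor instead of folding them into the config, so a kwargs-only pretraining_tp override would not be applied there; that appears to be the argument issue this commit addresses.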