LLM: fix pretraining_tp argument issue. (#9281)
parent 6b2a32eba2
commit 44b5fcc190
1 changed file with 4 additions and 1 deletion
@@ -93,7 +93,10 @@ class _BaseAutoModelClass:
             kwargs["torch_dtype"] = kwargs.get("torch_dtype", 'auto')
             # Avoid tensor parallel F.Linear Operations
             if "pretraining_tp" in config_dict:
-                kwargs["pretraining_tp"] = 1
+                if "config" in kwargs:
+                    setattr(kwargs["config"], "pretraining_tp", 1)
+                else:
+                    kwargs["pretraining_tp"] = 1
             q_k = load_in_low_bit if load_in_low_bit else "sym_int4"
             model = cls.load_convert(q_k, optimize_model, *args, **kwargs)
         else:
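For context, a minimal usage sketch of the case this patch targets, assuming the class is consumed like Hugging Face's from_pretrained API. The import path, checkpoint name, and keyword values below are illustrative assumptions, not taken from the diff (only load_in_low_bit and the "sym_int4" default appear in the changed code). Llama-2 style configs can ship with pretraining_tp > 1, which routes linear layers through a tensor-parallel weight-slicing path instead of a single F.linear call; when the caller passes a config object explicitly, a plain pretraining_tp kwarg does not reach that object, so the patch sets the attribute on it directly.

# Minimal sketch, not repository code: import path and checkpoint are hypothetical.
from transformers import AutoConfig
from bigdl.llm.transformers import AutoModelForCausalLM  # assumed module path

model_path = "meta-llama/Llama-2-7b-hf"  # illustrative checkpoint

# The caller builds and passes an explicit config object; such configs may
# carry pretraining_tp > 1.
config = AutoConfig.from_pretrained(model_path)

# With this patch, from_pretrained does setattr(config, "pretraining_tp", 1)
# on the supplied object instead of adding a separate pretraining_tp kwarg,
# so the low-bit conversion sees plain F.linear modules.
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    config=config,
    load_in_low_bit="sym_int4",  # value mirrors the default visible in the diff
)

Callers that do not pass a config object keep the previous behaviour: the else branch still injects pretraining_tp=1 as a keyword, which is applied when the config is loaded from the checkpoint.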