diff --git a/python/llm/src/bigdl/llm/transformers/model.py b/python/llm/src/bigdl/llm/transformers/model.py
index e498987f..36e985df 100644
--- a/python/llm/src/bigdl/llm/transformers/model.py
+++ b/python/llm/src/bigdl/llm/transformers/model.py
@@ -93,7 +93,10 @@ class _BaseAutoModelClass:
             kwargs["torch_dtype"] = kwargs.get("torch_dtype", 'auto')
             # Avoid tensor parallel F.Linear Operations
             if "pretraining_tp" in config_dict:
-                kwargs["pretraining_tp"] = 1
+                if "config" in kwargs:
+                    setattr(kwargs["config"], "pretraining_tp", 1)
+                else:
+                    kwargs["pretraining_tp"] = 1
             q_k = load_in_low_bit if load_in_low_bit else "sym_int4"
             model = cls.load_convert(q_k, optimize_model, *args, **kwargs)
         else:
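
For context, a minimal sketch of the logic this hunk introduces: when the caller already passes a config object to from_pretrained, the pretraining_tp override is set on that object directly, presumably because transformers only folds loose keyword overrides into a config it loads itself. The helper name _disable_pretraining_tp and the SimpleNamespace stand-in config below are illustrative, not part of the patch.

    from types import SimpleNamespace

    def _disable_pretraining_tp(kwargs: dict) -> None:
        """Force pretraining_tp=1 so the tensor-parallel F.linear path is skipped."""
        if "config" in kwargs:
            # A config object was supplied explicitly: mutate it in place,
            # since the model will be built from this object rather than
            # from loose keyword overrides.
            setattr(kwargs["config"], "pretraining_tp", 1)
        else:
            # No explicit config: pass the override as a kwarg and let
            # from_pretrained fold it into the config it loads.
            kwargs["pretraining_tp"] = 1

    # Toy usage with a hypothetical stand-in config object, for illustration only.
    kwargs = {"config": SimpleNamespace(pretraining_tp=2)}
    _disable_pretraining_tp(kwargs)
    assert kwargs["config"].pretraining_tp == 1

    kwargs = {}
    _disable_pretraining_tp(kwargs)
    assert kwargs["pretraining_tp"] == 1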