diff --git a/python/llm/src/bigdl/llm/transformers/model.py b/python/llm/src/bigdl/llm/transformers/model.py
index 5cd679cf..06dc88a4 100644
--- a/python/llm/src/bigdl/llm/transformers/model.py
+++ b/python/llm/src/bigdl/llm/transformers/model.py
@@ -80,6 +80,9 @@ class _BaseAutoModelClass:
             kwargs["low_cpu_mem_usage"] = True
             # set default torch_dtype='auto'
             kwargs["torch_dtype"] = kwargs.get("torch_dtype", 'auto')
+            # avoid tensor-parallel F.linear operations by forcing pretraining_tp=1
+            if "pretraining_tp" in config_dict:
+                kwargs["pretraining_tp"] = 1
             q_k = load_in_low_bit if load_in_low_bit else "sym_int4"
             model = cls.load_convert(q_k, *args, **kwargs)
         else:
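
For reference, a minimal usage sketch of what this guard does at load time (the model id and flags below are illustrative, not part of the patch): when a checkpoint's config sets pretraining_tp > 1, the low-bit loading path now overrides it to 1 so transformers takes the single F.linear code path that the quantized kernels can replace, rather than the sliced tensor-parallel matmuls.

# Minimal sketch, assuming bigdl-llm's public AutoModelForCausalLM wrapper;
# the model id is an example of a config that may ship "pretraining_tp" > 1.
from bigdl.llm.transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    load_in_4bit=True,  # routes through load_convert("sym_int4", ...) above
)
# With this patch, the override is applied automatically during loading.
assert model.config.pretraining_tp == 1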