diff --git a/python/llm/src/ipex_llm/transformers/model.py b/python/llm/src/ipex_llm/transformers/model.py
index 8aa002ee..8a8105fa 100644
--- a/python/llm/src/ipex_llm/transformers/model.py
+++ b/python/llm/src/ipex_llm/transformers/model.py
@@ -534,6 +534,9 @@ class _BaseAutoModelClass:
         :param pretrained_model_name_or_path: str value, Path to load the optimized model ckpt.
         :param optimize_model: boolean value, Whether to further optimize the low_bit llm model.
                                Default to be True.
+        :param pipeline_parallel_stages: int value, the number of GPUs allocated for
+            pipeline parallel. Default to be ``1``. Please set pipeline_parallel_stages > 1
+            to run pipeline parallel inference on multiple GPUs.
 
         :return: a model instance
         """
@@ -580,6 +583,8 @@ class _BaseAutoModelClass:
         embedding_qtype = kwargs.pop("embedding_qtype", None)
         sharded_metadata = None
 
+        pipeline_parallel_stages = kwargs.pop("pipeline_parallel_stages", 1)
+
         config_dict, _ = PretrainedConfig.get_config_dict(pretrained_model_name_or_path)
         bigdl_transformers_low_bit = config_dict.pop("bigdl_transformers_low_bit", False)
         bigdl_lcmu_enabled = config_dict.pop("bigdl_lcmu_enabled", True)
@@ -750,6 +755,16 @@ class _BaseAutoModelClass:
 
             # rwkv model linear layers has been rescaled
             if model.config.model_type == "rwkv":
                 model.rwkv.layers_are_rescaled = True
+
+        if pipeline_parallel_stages > 1:
+            from .pipeline_parallel import pipeline_parallel, pipeline_parallel_generate
+            model = pipeline_parallel(model, pipeline_parallel_stages)
+            import types
+            # add pipeline_parallel_generate to pretrained model dynamically
+            model.pipeline_parallel_generate = types.MethodType(pipeline_parallel_generate,
+                                                                model)
+            torch.distributed.barrier()
+
         return model
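
For context, a minimal usage sketch of the new pipeline_parallel_stages argument follows. The checkpoint path, prompt, device string, and launcher command below are illustrative assumptions, not part of this patch; the sketch also assumes pipeline_parallel() initializes the torch.distributed process group and places each stage on its own device, since load_low_bit() calls torch.distributed.barrier() unconditionally once pipeline_parallel_stages > 1.

    # Run with one process per GPU; the launcher command is an assumption, e.g.:
    #   mpirun -n 2 python pp_example.py
    from transformers import AutoTokenizer
    from ipex_llm.transformers import AutoModelForCausalLM

    # Hypothetical path to a checkpoint previously saved with save_low_bit().
    model_path = "./llama-2-7b-chat-int4"

    # pipeline_parallel_stages=2 shards the model into two stages, one per GPU rank.
    model = AutoModelForCausalLM.load_low_bit(model_path,
                                              optimize_model=True,
                                              pipeline_parallel_stages=2)

    tokenizer = AutoTokenizer.from_pretrained(model_path)
    # "xpu" is an assumed target device; the first stage's device may differ per rank.
    input_ids = tokenizer("What is AI?", return_tensors="pt").input_ids.to("xpu")

    # pipeline_parallel_generate() is the method this patch binds onto the model.
    output = model.pipeline_parallel_generate(input_ids, max_new_tokens=32)
    print(tokenizer.decode(output[0], skip_special_tokens=True))

Binding pipeline_parallel_generate via types.MethodType leaves the model's existing generate() path untouched while exposing the multi-GPU path, and the barrier() synchronizes all ranks so every stage finishes loading before inference begins.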