diff --git a/docs/readthedocs/source/doc/Orca/Overview/distributed-training-inference.md b/docs/readthedocs/source/doc/Orca/Overview/distributed-training-inference.md index 957aa3d9..a8b4b5a5 100644 --- a/docs/readthedocs/source/doc/Orca/Overview/distributed-training-inference.md +++ b/docs/readthedocs/source/doc/Orca/Overview/distributed-training-inference.md @@ -304,3 +304,43 @@ result_shards = est.predict(shards) The input to `predict` methods can be an *XShards*, or a *numpy array*. See the *data-parallel processing pipeline* [page](./data-parallel-processing.html) for more details. View the related [Python API doc](https://bigdl.readthedocs.io/en/latest/doc/PythonAPI/Orca/orca.html#orca-learn-openvino-estimator) for more details. + +### 7. MPI Estimator +The Orca MPI Estimator is to run distributed training job based on MPI. + +#### Preparation: +* Configure password-less ssh from the master node (the one you'll launch training from) to all other nodes. + +* All hosts have the same working directory. +* All hosts have the same Python environment in the same location. + +#### Train +Then the user may create a MPI Estimator as follows: +```python +from bigdl.orca.learn.mpi import MPIEstimator + +est = MPIEstimator(model_creator=model_creator, + optimizer_creator=optimizer_creator, + loss_creator=None, + metrics=None, + scheduler_creator=None, + config=config, + init_func=init, # Init the distributed environment for MPI if any + hosts=hosts, + workers_per_node=workers_per_node, + env=None) +``` +Then the user can perform distributed model training as follows: +```python +# read spark Dataframe +df = spark.read.parquet("data.parquet") + +# distributed model training +est.fit(data=df, epochs=1, batch_size=4, feature_col="feature", label_cols="label") + +``` +The input to `fit` methods can be an Spark Dataframe, or a callable function to return a `torch.utils.data.DataLoader`. + +View the related [Python API doc](https://bigdl.readthedocs.io/en/latest/doc/PythonAPI/Orca/orca.html#orca-learn-mpi-mpi-estimator) for more details. + + diff --git a/docs/readthedocs/source/doc/PythonAPI/Orca/orca.rst b/docs/readthedocs/source/doc/PythonAPI/Orca/orca.rst index a2a45a27..bc525446 100644 --- a/docs/readthedocs/source/doc/PythonAPI/Orca/orca.rst +++ b/docs/readthedocs/source/doc/PythonAPI/Orca/orca.rst @@ -77,3 +77,11 @@ orca.learn.openvino.estimator :members: :undoc-members: :show-inheritance: + +orca.learn.mpi.mpi_estimator +------------------------------ + +.. autoclass:: bigdl.orca.learn.mpi.MPIEstimator + :members: + :undoc-members: + :show-inheritance: