Update MPI Estimator to support Pytorch IPEX training (#8303)
* update * update * update * update * update * update with comments * update * update * style * style * add doc * style * style
parent 35fdf94031
commit 5d90ca2dac
2 changed files with 48 additions and 0 deletions
@@ -304,3 +304,43 @@ result_shards = est.predict(shards)
The input to `predict` methods can be an *XShards*, or a *numpy array*. See the *data-parallel processing pipeline* [page](./data-parallel-processing.html) for more details.

View the related [Python API doc](https://bigdl.readthedocs.io/en/latest/doc/PythonAPI/Orca/orca.html#orca-learn-openvino-estimator) for more details.

### 7. MPI Estimator

The Orca MPI Estimator runs distributed training jobs based on MPI.

#### Preparation:

* Configure password-less ssh from the master node (the one you'll launch training from) to all other nodes.

* Make sure all hosts use the same working directory.
* Make sure all hosts have the same Python environment in the same location.
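
The preparation steps above can be checked before launching a job. The following is a minimal sketch, not part of Orca: the helper names `ssh_check_command` and `check_hosts`, the host addresses, and the paths are all hypothetical. It relies only on the standard `ssh` client, whose `BatchMode=yes` option makes the connection fail instead of prompting for a password.

```python
import shlex
import subprocess
from typing import Dict, List


def ssh_check_command(host: str, remote_cmd: str) -> List[str]:
    """Build an ssh command that fails rather than prompting for a password."""
    # BatchMode=yes disables interactive prompts, so the command succeeds only
    # when key-based (password-less) login is already configured.
    return ["ssh", "-o", "BatchMode=yes", "-o", "ConnectTimeout=5", host, remote_cmd]


def check_hosts(hosts: List[str], workdir: str, python_path: str) -> Dict[str, bool]:
    """Check that each host is reachable and has the same directory and interpreter."""
    remote_cmd = f"test -d {shlex.quote(workdir)} && test -x {shlex.quote(python_path)}"
    results = {}
    for host in hosts:
        proc = subprocess.run(ssh_check_command(host, remote_cmd), capture_output=True)
        results[host] = proc.returncode == 0
    return results


# Example call (hypothetical addresses and paths):
# check_hosts(["172.16.0.2", "172.16.0.3"], "/opt/work", "/opt/conda/envs/orca/bin/python")
```

Run this from the master node; any host mapped to `False` is either unreachable without a password or is missing the working directory or the Python interpreter at the expected location.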

#### Train

Then the user may create an MPI Estimator as follows:
```python
from bigdl.orca.learn.mpi import MPIEstimator

est = MPIEstimator(model_creator=model_creator,
                   optimizer_creator=optimizer_creator,
                   loss_creator=None,
                   metrics=None,
                   scheduler_creator=None,
                   config=config,
                   init_func=init,  # Init the distributed environment for MPI if any
                   hosts=hosts,
                   workers_per_node=workers_per_node,
                   env=None)
```
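
The constructor call above leaves `model_creator`, `optimizer_creator`, `config`, `hosts`, and `workers_per_node` undefined. The sketch below shows one way they might look. It assumes the creators follow the same convention as the Orca PyTorch Estimator (`model_creator(config)` and `optimizer_creator(model, config)`) and that `hosts` accepts a list of addresses; neither assumption is confirmed by this diff, and all values are placeholders.

```python
# Hypothetical values for illustration; none of them come from the original docs.
config = {"lr": 0.01, "in_features": 10}
hosts = ["172.16.0.1", "172.16.0.2"]  # placeholder addresses (list form is an assumption)
workers_per_node = 2


def model_creator(config):
    """Build the model on the worker (assumed signature: config dict in, model out)."""
    # torch is imported inside the function so the import happens in the
    # worker process that calls the creator.
    import torch.nn as nn
    return nn.Linear(config["in_features"], 1)


def optimizer_creator(model, config):
    """Build the optimizer for the model returned by model_creator (assumed signature)."""
    import torch
    return torch.optim.SGD(model.parameters(), lr=config["lr"])
```

Under these assumptions the job would start `len(hosts) * workers_per_node` training processes, four in this example.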
Then the user can perform distributed model training as follows:
```python
# Read a Spark DataFrame (assumes an existing SparkSession named `spark`)
df = spark.read.parquet("data.parquet")

# Distributed model training
est.fit(data=df, epochs=1, batch_size=4, feature_col="feature", label_cols="label")
```
The input to the `fit` method can be a Spark DataFrame or a callable that returns a `torch.utils.data.DataLoader`.

View the related [Python API doc](https://bigdl.readthedocs.io/en/latest/doc/PythonAPI/Orca/orca.html#orca-learn-mpi-mpi-estimator) for more details.
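
As an illustration of the callable input to `fit` mentioned above, here is a minimal sketch of a function that returns a `torch.utils.data.DataLoader`. The name `train_loader_creator`, the assumption that it receives the `config` dict, and the random data are all hypothetical; check the API doc for the exact signature the estimator expects.

```python
def train_loader_creator(config):
    """Return a DataLoader over random data (assumed signature: config dict in)."""
    import torch
    from torch.utils.data import DataLoader, TensorDataset

    # Synthetic features and labels stand in for a real dataset.
    features = torch.randn(64, config["in_features"])
    labels = torch.randn(64, 1)
    dataset = TensorDataset(features, labels)
    return DataLoader(dataset, batch_size=config["batch_size"], shuffle=True)


# Hypothetical config used only by this sketch.
loader_config = {"in_features": 10, "batch_size": 4}
```

Such a function would be passed as `data=train_loader_creator` in place of the DataFrame, in which case the `feature_col` and `label_cols` arguments would presumably not be needed.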
@@ -77,3 +77,11 @@ orca.learn.openvino.estimator
    :members:
    :undoc-members:
    :show-inheritance:

orca.learn.mpi.mpi_estimator
------------------------------

.. autoclass:: bigdl.orca.learn.mpi.MPIEstimator
    :members:
    :undoc-members:
    :show-inheritance: