Update MPI Estimator to support Pytorch IPEX training (#8303)
* update * update * update * update * update * update with comments * update * update * style * style * add doc * style * style
This commit is contained in:
		
							parent
							
								
									35fdf94031
								
							
						
					
					
						commit
						5d90ca2dac
					
				
					 2 changed files with 48 additions and 0 deletions
				
			
		| 
						 | 
					@ -304,3 +304,43 @@ result_shards = est.predict(shards)
 | 
				
			||||||
The input to `predict` methods can be an *XShards*, or a *numpy array*. See the *data-parallel processing pipeline* [page](./data-parallel-processing.html) for more details.
 | 
					The input to `predict` methods can be an *XShards*, or a *numpy array*. See the *data-parallel processing pipeline* [page](./data-parallel-processing.html) for more details.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
View the related [Python API doc](https://bigdl.readthedocs.io/en/latest/doc/PythonAPI/Orca/orca.html#orca-learn-openvino-estimator) for more details.
 | 
					View the related [Python API doc](https://bigdl.readthedocs.io/en/latest/doc/PythonAPI/Orca/orca.html#orca-learn-openvino-estimator) for more details.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					### 7. MPI Estimator
 | 
				
			||||||
 | 
					The Orca MPI Estimator is to run distributed training job based on MPI.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#### Preparation:
 | 
				
			||||||
 | 
					* Configure password-less ssh from the master node (the one you'll launch training from) to all other nodes.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					* All hosts have the same working directory.
 | 
				
			||||||
 | 
					* All hosts have the same Python environment in the same location.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#### Train
 | 
				
			||||||
 | 
					Then the user may create a MPI Estimator as follows:
 | 
				
			||||||
 | 
					```python
 | 
				
			||||||
 | 
					from bigdl.orca.learn.mpi import MPIEstimator
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					est = MPIEstimator(model_creator=model_creator,
 | 
				
			||||||
 | 
					                   optimizer_creator=optimizer_creator,
 | 
				
			||||||
 | 
					                   loss_creator=None,
 | 
				
			||||||
 | 
					                   metrics=None,
 | 
				
			||||||
 | 
					                   scheduler_creator=None,
 | 
				
			||||||
 | 
					                   config=config,
 | 
				
			||||||
 | 
					                   init_func=init,  # Init the distributed environment for MPI if any
 | 
				
			||||||
 | 
					                   hosts=hosts,
 | 
				
			||||||
 | 
					                   workers_per_node=workers_per_node,
 | 
				
			||||||
 | 
					                   env=None)                   
 | 
				
			||||||
 | 
					```
 | 
				
			||||||
 | 
					Then the user can perform distributed model training as follows:
 | 
				
			||||||
 | 
					```python
 | 
				
			||||||
 | 
					# read spark Dataframe
 | 
				
			||||||
 | 
					df = spark.read.parquet("data.parquet")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# distributed model training
 | 
				
			||||||
 | 
					est.fit(data=df, epochs=1, batch_size=4, feature_col="feature", label_cols="label")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					```
 | 
				
			||||||
 | 
					The input to `fit` methods can be an Spark Dataframe, or a callable  function to return a `torch.utils.data.DataLoader`. 
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					View the related [Python API doc](https://bigdl.readthedocs.io/en/latest/doc/PythonAPI/Orca/orca.html#orca-learn-mpi-mpi-estimator) for more details.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -77,3 +77,11 @@ orca.learn.openvino.estimator
 | 
				
			||||||
    :members:
 | 
					    :members:
 | 
				
			||||||
    :undoc-members:
 | 
					    :undoc-members:
 | 
				
			||||||
    :show-inheritance:
 | 
					    :show-inheritance:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					orca.learn.mpi.mpi_estimator
 | 
				
			||||||
 | 
					------------------------------
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					.. autoclass:: bigdl.orca.learn.mpi.MPIEstimator
 | 
				
			||||||
 | 
					    :members:
 | 
				
			||||||
 | 
					    :undoc-members:
 | 
				
			||||||
 | 
					    :show-inheritance:
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in a new issue