[Orca] Update Documents for Tf2estimator on Pyspark Backend (#4308)
* update tf2estimator on pyspark backend docs
This commit is contained in:
		
							parent
							
								
									5d4743a12a
								
							
						
					
					
						commit
						23aa10345f
					
				
					 1 changed files with 64 additions and 10 deletions
				
			
		| 
						 | 
				
			
			@ -60,11 +60,14 @@ predictions = est.predict(data=df,
 | 
			
		|||
```
 | 
			
		||||
The `data` argument in `fit` method can be a Spark DataFrame, an *XShards* or a `tf.data.Dataset`. The `data` argument in `predict` method can be a spark DataFrame or an *XShards*. See the *data-parallel processing pipeline* [page](./data-parallel-processing.md) for more details.
 | 
			
		||||
 | 
			
		||||
View the related [Python API doc]() for more details.
 | 
			
		||||
View the related [Python API doc](https://bigdl.readthedocs.io/en/latest/doc/PythonAPI/Orca/orca.html#module-bigdl.orca.learn.tf.estimator) for more details.
 | 
			
		||||
 | 
			
		||||
#### **2.2 TensorFlow 2.x and Keras 2.4+**
 | 
			
		||||
 | 
			
		||||
Users can create an `Estimator` for TensorFlow 2.x from a Keras model (using a _Model Creator Function_). For example:
 | 
			
		||||
**Using `tf2` or *Horovod* backend**
 | 
			
		||||
 | 
			
		||||
Users can create an `Estimator` for TensorFlow 2.x from a Keras model (using a _Model Creator Function_) when the backend is
 | 
			
		||||
`tf2` (currently default for TF2) or *Horovod*. For example:
 | 
			
		||||
 | 
			
		||||
```python
 | 
			
		||||
def model_creator(config):
 | 
			
		||||
| 
						 | 
				
			
			@ -73,7 +76,7 @@ def model_creator(config):
 | 
			
		|||
                  loss='sparse_categorical_crossentropy',
 | 
			
		||||
                  metrics=['accuracy'])
 | 
			
		||||
    return model
 | 
			
		||||
est = Estimator.from_keras(model_creator=model_creator)
 | 
			
		||||
est = Estimator.from_keras(model_creator=model_creator) # or backend="horovod"
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
The `model_creator` argument should be a function that takes a `config` dictionary and returns a compiled Keras model.
 | 
			
		||||
| 
						 | 
				
			
			@ -95,9 +98,60 @@ predictions = est.predict(data=df,
 | 
			
		|||
 | 
			
		||||
The `data` argument in `fit` method can be a spark DataFrame, an *XShards* or a *Data Creator Function* (that returns a `tf.data.Dataset`). The `data` argument in `predict` method can be a spark DataFrame or an *XShards*. See the *data-parallel processing pipeline* [page](./data-parallel-processing.md) for more details.
 | 
			
		||||
 | 
			
		||||
View the related [Python API doc]() for more details.
 | 
			
		||||
View the related [Python API doc](https://bigdl.readthedocs.io/en/latest/doc/PythonAPI/Orca/orca.html#orca-learn-tf2-tf2-ray-estimator) for more details.
 | 
			
		||||
 | 
			
		||||
***For more details, view the distributed TensorFlow training/inference [page]().***
 | 
			
		||||
**Using *spark* backend**
 | 
			
		||||
 | 
			
		||||
Users can create an `Estimator` for TensorFlow 2.x using the *spark* backend as follows:
 | 
			
		||||
 | 
			
		||||
```python
 | 
			
		||||
def model_creator(config):
 | 
			
		||||
    model = create_keras_lenet_model()
 | 
			
		||||
    model.compile(**compile_args(config))
 | 
			
		||||
    return model
 | 
			
		||||
 | 
			
		||||
def compile_args(config):
 | 
			
		||||
    if "lr" in config:
 | 
			
		||||
        lr = config["lr"]
 | 
			
		||||
    else:
 | 
			
		||||
        lr = 1e-2
 | 
			
		||||
    args = {
 | 
			
		||||
        "optimizer": keras.optimizers.SGD(lr),
 | 
			
		||||
        "loss": "mean_squared_error",
 | 
			
		||||
        "metrics": ["mean_squared_error"]
 | 
			
		||||
    }
 | 
			
		||||
    return args
 | 
			
		||||
 | 
			
		||||
est = Estimator.from_keras(model_creator=model_creator,
 | 
			
		||||
                           config={"lr": 1e-2},
 | 
			
		||||
                           workers_per_node=2,
 | 
			
		||||
                           backend="spark",
 | 
			
		||||
                           model_dir=model_dir)
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
The `model_creator` argument should be a function that takes a `config` dictionary and returns a compiled Keras model.
 | 
			
		||||
The `model_dir` argument is required for *spark* backend, it should be a share filesystem path which can be accessed by executors for culster mode.  
 | 
			
		||||
 | 
			
		||||
Then users can perform distributed model training and inference as follows:
 | 
			
		||||
 | 
			
		||||
```python
 | 
			
		||||
def train_data_creator(config, batch_size):
 | 
			
		||||
    dataset = tfds.load(name="mnist", split="train")
 | 
			
		||||
    dataset = dataset.map(preprocess)
 | 
			
		||||
    dataset = dataset.batch(batch_size)
 | 
			
		||||
    return dataset
 | 
			
		||||
stats = est.fit(data=train_data_creator,
 | 
			
		||||
                epochs=max_epoch,
 | 
			
		||||
                steps_per_epoch=total_size // batch_size)
 | 
			
		||||
predictions = est.predict(data=df,
 | 
			
		||||
                          feature_cols=['image']).collect()
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
The `data` argument in `fit` method can be a spark DataFrame, an *XShards* or a *Data Creator Function* (that returns a `tf.data.Dataset`). The `data` argument in `predict` method can be a spark DataFrame or an *XShards*. See the *data-parallel processing pipeline* [page](./data-parallel-processing.md) for more details.
 | 
			
		||||
 | 
			
		||||
View the related [Python API doc](https://bigdl.readthedocs.io/en/latest/doc/PythonAPI/Orca/orca.html#orca-learn-tf2-tf2-spark-estimator) for more details.
 | 
			
		||||
 | 
			
		||||
***For more details, view the distributed TensorFlow training/inference [page]()<TODO: link to be added>.***
 | 
			
		||||
 | 
			
		||||
### **3. PyTorch Estimator**
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -123,7 +177,7 @@ predictions = est.predict(xshards)
 | 
			
		|||
 | 
			
		||||
The input to `fit` methods can be a `torch.utils.data.DataLoader`, a Spark Dataframe, an *XShards*, or a *Data Creator Function* (that returns a `torch.utils.data.DataLoader`). The input to `predict` methods should be a Spark Dataframe, or an *XShards*. See the *data-parallel processing pipeline* [page](./data-parallel-processing.md) for more details.
 | 
			
		||||
 | 
			
		||||
View the related [Python API doc]() for more details.
 | 
			
		||||
View the related [Python API doc](https://bigdl.readthedocs.io/en/latest/doc/PythonAPI/Orca/orca.html#orca-learn-pytorch-pytorch-spark-estimator) for more details.
 | 
			
		||||
 | 
			
		||||
**Using `torch.distributed` or *Horovod* backend**
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -155,7 +209,7 @@ predictions = est.predict(data=df,
 | 
			
		|||
 | 
			
		||||
The input to `fit` methods can be a Spark DataFrame, an *XShards*, or a *Data Creator Function* (that returns a `torch.utils.data.DataLoader`). The `data` argument in `predict` method can be a Spark DataFrame or an *XShards*. See the *data-parallel processing pipeline* [page](./data-parallel-processing.md) for more details.
 | 
			
		||||
 | 
			
		||||
View the related [Python API doc]() for more details.
 | 
			
		||||
View the related [Python API doc](https://bigdl.readthedocs.io/en/latest/doc/PythonAPI/Orca/orca.html#orca-learn-pytorch-pytorch-ray-estimator) for more details.
 | 
			
		||||
 | 
			
		||||
***For more details, view the distributed PyTorch training/inference [page]()<TODO: link to be added>.***
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -194,7 +248,7 @@ est.fit(get_train_data_iter, epochs=2)
 | 
			
		|||
 | 
			
		||||
The input to `fit` methods can be an *XShards*, or a *Data Creator Function* (that returns an `MXNet DataIter/DataLoader`). See the *data-parallel processing pipeline* [page](./data-parallel-processing.html) for more details.
 | 
			
		||||
 | 
			
		||||
View the related [Python API doc]() for more details.
 | 
			
		||||
View the related [Python API doc]()<TODO: link to be added> for more details.
 | 
			
		||||
 | 
			
		||||
### **5. BigDL Estimator**
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -224,7 +278,7 @@ result_df = est.predict(df)
 | 
			
		|||
 | 
			
		||||
The input to `fit` and `predict` methods can be a *Spark Dataframe*, or an *XShards*. See the *data-parallel processing pipeline* [page](./data-parallel-processing.html) for more details.
 | 
			
		||||
 | 
			
		||||
View the related [Python API doc]() for more details.
 | 
			
		||||
View the related [Python API doc](https://bigdl.readthedocs.io/en/latest/doc/PythonAPI/Orca/orca.html#module-bigdl.orca.learn.bigdl.estimator) for more details.
 | 
			
		||||
 | 
			
		||||
### **6. OpenVINO Estimator**
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -249,4 +303,4 @@ result_shards = est.predict(shards)
 | 
			
		|||
 | 
			
		||||
The input to `predict` methods can be an *XShards*, or a *numpy array*. See the *data-parallel processing pipeline* [page](./data-parallel-processing.html) for more details.
 | 
			
		||||
 | 
			
		||||
View the related [Python API doc]() for more details.
 | 
			
		||||
View the related [Python API doc](https://bigdl.readthedocs.io/en/latest/doc/PythonAPI/Orca/orca.html#orca-learn-openvino-estimator) for more details.
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in a new issue