[Orca] Update Documents for Tf2estimator on Pyspark Backend (#4308)

* update tf2estimator on pyspark backend docs
This commit is contained in:
SONG Ge 2022-03-31 09:49:56 +08:00 committed by GitHub
parent 5d4743a12a
commit 23aa10345f

View file

@ -60,11 +60,14 @@ predictions = est.predict(data=df,
```
The `data` argument in `fit` method can be a Spark DataFrame, an *XShards* or a `tf.data.Dataset`. The `data` argument in `predict` method can be a spark DataFrame or an *XShards*. See the *data-parallel processing pipeline* [page](./data-parallel-processing.md) for more details.
View the related [Python API doc]() for more details.
View the related [Python API doc](https://bigdl.readthedocs.io/en/latest/doc/PythonAPI/Orca/orca.html#module-bigdl.orca.learn.tf.estimator) for more details.
#### **2.2 TensorFlow 2.x and Keras 2.4+**
Users can create an `Estimator` for TensorFlow 2.x from a Keras model (using a _Model Creator Function_). For example:
**Using `tf2` or *Horovod* backend**
Users can create an `Estimator` for TensorFlow 2.x from a Keras model (using a _Model Creator Function_) when the backend is
`tf2` (currently default for TF2) or *Horovod*. For example:
```python
def model_creator(config):
@ -73,7 +76,7 @@ def model_creator(config):
loss='sparse_categorical_crossentropy',
metrics=['accuracy'])
return model
est = Estimator.from_keras(model_creator=model_creator)
est = Estimator.from_keras(model_creator=model_creator) # or backend="horovod"
```
The `model_creator` argument should be a function that takes a `config` dictionary and returns a compiled Keras model.
@ -95,9 +98,60 @@ predictions = est.predict(data=df,
The `data` argument in `fit` method can be a spark DataFrame, an *XShards* or a *Data Creator Function* (that returns a `tf.data.Dataset`). The `data` argument in `predict` method can be a spark DataFrame or an *XShards*. See the *data-parallel processing pipeline* [page](./data-parallel-processing.md) for more details.
View the related [Python API doc]() for more details.
View the related [Python API doc](https://bigdl.readthedocs.io/en/latest/doc/PythonAPI/Orca/orca.html#orca-learn-tf2-tf2-ray-estimator) for more details.
***For more details, view the distributed TensorFlow training/inference [page]().***
**Using *spark* backend**
Users can create an `Estimator` for TensorFlow 2.x using the *spark* backend as follows:
```python
def model_creator(config):
model = create_keras_lenet_model()
model.compile(**compile_args(config))
return model
def compile_args(config):
if "lr" in config:
lr = config["lr"]
else:
lr = 1e-2
args = {
"optimizer": keras.optimizers.SGD(lr),
"loss": "mean_squared_error",
"metrics": ["mean_squared_error"]
}
return args
est = Estimator.from_keras(model_creator=model_creator,
config={"lr": 1e-2},
workers_per_node=2,
backend="spark",
model_dir=model_dir)
```
The `model_creator` argument should be a function that takes a `config` dictionary and returns a compiled Keras model.
The `model_dir` argument is required for *spark* backend, it should be a share filesystem path which can be accessed by executors for culster mode.
Then users can perform distributed model training and inference as follows:
```python
def train_data_creator(config, batch_size):
dataset = tfds.load(name="mnist", split="train")
dataset = dataset.map(preprocess)
dataset = dataset.batch(batch_size)
return dataset
stats = est.fit(data=train_data_creator,
epochs=max_epoch,
steps_per_epoch=total_size // batch_size)
predictions = est.predict(data=df,
feature_cols=['image']).collect()
```
The `data` argument in `fit` method can be a spark DataFrame, an *XShards* or a *Data Creator Function* (that returns a `tf.data.Dataset`). The `data` argument in `predict` method can be a spark DataFrame or an *XShards*. See the *data-parallel processing pipeline* [page](./data-parallel-processing.md) for more details.
View the related [Python API doc](https://bigdl.readthedocs.io/en/latest/doc/PythonAPI/Orca/orca.html#orca-learn-tf2-tf2-spark-estimator) for more details.
***For more details, view the distributed TensorFlow training/inference [page]()<TODO: link to be added>.***
### **3. PyTorch Estimator**
@ -123,7 +177,7 @@ predictions = est.predict(xshards)
The input to `fit` methods can be a `torch.utils.data.DataLoader`, a Spark Dataframe, an *XShards*, or a *Data Creator Function* (that returns a `torch.utils.data.DataLoader`). The input to `predict` methods should be a Spark Dataframe, or an *XShards*. See the *data-parallel processing pipeline* [page](./data-parallel-processing.md) for more details.
View the related [Python API doc]() for more details.
View the related [Python API doc](https://bigdl.readthedocs.io/en/latest/doc/PythonAPI/Orca/orca.html#orca-learn-pytorch-pytorch-spark-estimator) for more details.
**Using `torch.distributed` or *Horovod* backend**
@ -155,7 +209,7 @@ predictions = est.predict(data=df,
The input to `fit` methods can be a Spark DataFrame, an *XShards*, or a *Data Creator Function* (that returns a `torch.utils.data.DataLoader`). The `data` argument in `predict` method can be a Spark DataFrame or an *XShards*. See the *data-parallel processing pipeline* [page](./data-parallel-processing.md) for more details.
View the related [Python API doc]() for more details.
View the related [Python API doc](https://bigdl.readthedocs.io/en/latest/doc/PythonAPI/Orca/orca.html#orca-learn-pytorch-pytorch-ray-estimator) for more details.
***For more details, view the distributed PyTorch training/inference [page]()<TODO: link to be added>.***
@ -194,7 +248,7 @@ est.fit(get_train_data_iter, epochs=2)
The input to `fit` methods can be an *XShards*, or a *Data Creator Function* (that returns an `MXNet DataIter/DataLoader`). See the *data-parallel processing pipeline* [page](./data-parallel-processing.html) for more details.
View the related [Python API doc]() for more details.
View the related [Python API doc]()<TODO: link to be added> for more details.
### **5. BigDL Estimator**
@ -224,7 +278,7 @@ result_df = est.predict(df)
The input to `fit` and `predict` methods can be a *Spark Dataframe*, or an *XShards*. See the *data-parallel processing pipeline* [page](./data-parallel-processing.html) for more details.
View the related [Python API doc]() for more details.
View the related [Python API doc](https://bigdl.readthedocs.io/en/latest/doc/PythonAPI/Orca/orca.html#module-bigdl.orca.learn.bigdl.estimator) for more details.
### **6. OpenVINO Estimator**
@ -249,4 +303,4 @@ result_shards = est.predict(shards)
The input to `predict` methods can be an *XShards*, or a *numpy array*. See the *data-parallel processing pipeline* [page](./data-parallel-processing.html) for more details.
View the related [Python API doc]() for more details.
View the related [Python API doc](https://bigdl.readthedocs.io/en/latest/doc/PythonAPI/Orca/orca.html#orca-learn-openvino-estimator) for more details.