[Orca] Update Documents for Tf2estimator on Pyspark Backend (#4308)
* update tf2estimator on pyspark backend docs
This commit is contained in:
parent
5d4743a12a
commit
23aa10345f
1 changed files with 64 additions and 10 deletions
|
|
@ -60,11 +60,14 @@ predictions = est.predict(data=df,
|
|||
```
|
||||
The `data` argument in `fit` method can be a Spark DataFrame, an *XShards* or a `tf.data.Dataset`. The `data` argument in `predict` method can be a spark DataFrame or an *XShards*. See the *data-parallel processing pipeline* [page](./data-parallel-processing.md) for more details.
|
||||
|
||||
View the related [Python API doc]() for more details.
|
||||
View the related [Python API doc](https://bigdl.readthedocs.io/en/latest/doc/PythonAPI/Orca/orca.html#module-bigdl.orca.learn.tf.estimator) for more details.
|
||||
|
||||
#### **2.2 TensorFlow 2.x and Keras 2.4+**
|
||||
|
||||
Users can create an `Estimator` for TensorFlow 2.x from a Keras model (using a _Model Creator Function_). For example:
|
||||
**Using `tf2` or *Horovod* backend**
|
||||
|
||||
Users can create an `Estimator` for TensorFlow 2.x from a Keras model (using a _Model Creator Function_) when the backend is
|
||||
`tf2` (currently default for TF2) or *Horovod*. For example:
|
||||
|
||||
```python
|
||||
def model_creator(config):
|
||||
|
|
@ -73,7 +76,7 @@ def model_creator(config):
|
|||
loss='sparse_categorical_crossentropy',
|
||||
metrics=['accuracy'])
|
||||
return model
|
||||
est = Estimator.from_keras(model_creator=model_creator)
|
||||
est = Estimator.from_keras(model_creator=model_creator) # or backend="horovod"
|
||||
```
|
||||
|
||||
The `model_creator` argument should be a function that takes a `config` dictionary and returns a compiled Keras model.
|
||||
|
|
@ -95,9 +98,60 @@ predictions = est.predict(data=df,
|
|||
|
||||
The `data` argument in `fit` method can be a spark DataFrame, an *XShards* or a *Data Creator Function* (that returns a `tf.data.Dataset`). The `data` argument in `predict` method can be a spark DataFrame or an *XShards*. See the *data-parallel processing pipeline* [page](./data-parallel-processing.md) for more details.
|
||||
|
||||
View the related [Python API doc]() for more details.
|
||||
View the related [Python API doc](https://bigdl.readthedocs.io/en/latest/doc/PythonAPI/Orca/orca.html#orca-learn-tf2-tf2-ray-estimator) for more details.
|
||||
|
||||
***For more details, view the distributed TensorFlow training/inference [page]().***
|
||||
**Using *spark* backend**
|
||||
|
||||
Users can create an `Estimator` for TensorFlow 2.x using the *spark* backend as follows:
|
||||
|
||||
```python
|
||||
def model_creator(config):
|
||||
model = create_keras_lenet_model()
|
||||
model.compile(**compile_args(config))
|
||||
return model
|
||||
|
||||
def compile_args(config):
|
||||
if "lr" in config:
|
||||
lr = config["lr"]
|
||||
else:
|
||||
lr = 1e-2
|
||||
args = {
|
||||
"optimizer": keras.optimizers.SGD(lr),
|
||||
"loss": "mean_squared_error",
|
||||
"metrics": ["mean_squared_error"]
|
||||
}
|
||||
return args
|
||||
|
||||
est = Estimator.from_keras(model_creator=model_creator,
|
||||
config={"lr": 1e-2},
|
||||
workers_per_node=2,
|
||||
backend="spark",
|
||||
model_dir=model_dir)
|
||||
```
|
||||
|
||||
The `model_creator` argument should be a function that takes a `config` dictionary and returns a compiled Keras model.
|
||||
The `model_dir` argument is required for *spark* backend, it should be a share filesystem path which can be accessed by executors for culster mode.
|
||||
|
||||
Then users can perform distributed model training and inference as follows:
|
||||
|
||||
```python
|
||||
def train_data_creator(config, batch_size):
|
||||
dataset = tfds.load(name="mnist", split="train")
|
||||
dataset = dataset.map(preprocess)
|
||||
dataset = dataset.batch(batch_size)
|
||||
return dataset
|
||||
stats = est.fit(data=train_data_creator,
|
||||
epochs=max_epoch,
|
||||
steps_per_epoch=total_size // batch_size)
|
||||
predictions = est.predict(data=df,
|
||||
feature_cols=['image']).collect()
|
||||
```
|
||||
|
||||
The `data` argument in `fit` method can be a spark DataFrame, an *XShards* or a *Data Creator Function* (that returns a `tf.data.Dataset`). The `data` argument in `predict` method can be a spark DataFrame or an *XShards*. See the *data-parallel processing pipeline* [page](./data-parallel-processing.md) for more details.
|
||||
|
||||
View the related [Python API doc](https://bigdl.readthedocs.io/en/latest/doc/PythonAPI/Orca/orca.html#orca-learn-tf2-tf2-spark-estimator) for more details.
|
||||
|
||||
***For more details, view the distributed TensorFlow training/inference [page]()<TODO: link to be added>.***
|
||||
|
||||
### **3. PyTorch Estimator**
|
||||
|
||||
|
|
@ -123,7 +177,7 @@ predictions = est.predict(xshards)
|
|||
|
||||
The input to `fit` methods can be a `torch.utils.data.DataLoader`, a Spark Dataframe, an *XShards*, or a *Data Creator Function* (that returns a `torch.utils.data.DataLoader`). The input to `predict` methods should be a Spark Dataframe, or an *XShards*. See the *data-parallel processing pipeline* [page](./data-parallel-processing.md) for more details.
|
||||
|
||||
View the related [Python API doc]() for more details.
|
||||
View the related [Python API doc](https://bigdl.readthedocs.io/en/latest/doc/PythonAPI/Orca/orca.html#orca-learn-pytorch-pytorch-spark-estimator) for more details.
|
||||
|
||||
**Using `torch.distributed` or *Horovod* backend**
|
||||
|
||||
|
|
@ -155,7 +209,7 @@ predictions = est.predict(data=df,
|
|||
|
||||
The input to `fit` methods can be a Spark DataFrame, an *XShards*, or a *Data Creator Function* (that returns a `torch.utils.data.DataLoader`). The `data` argument in `predict` method can be a Spark DataFrame or an *XShards*. See the *data-parallel processing pipeline* [page](./data-parallel-processing.md) for more details.
|
||||
|
||||
View the related [Python API doc]() for more details.
|
||||
View the related [Python API doc](https://bigdl.readthedocs.io/en/latest/doc/PythonAPI/Orca/orca.html#orca-learn-pytorch-pytorch-ray-estimator) for more details.
|
||||
|
||||
***For more details, view the distributed PyTorch training/inference [page]()<TODO: link to be added>.***
|
||||
|
||||
|
|
@ -194,7 +248,7 @@ est.fit(get_train_data_iter, epochs=2)
|
|||
|
||||
The input to `fit` methods can be an *XShards*, or a *Data Creator Function* (that returns an `MXNet DataIter/DataLoader`). See the *data-parallel processing pipeline* [page](./data-parallel-processing.html) for more details.
|
||||
|
||||
View the related [Python API doc]() for more details.
|
||||
View the related [Python API doc]()<TODO: link to be added> for more details.
|
||||
|
||||
### **5. BigDL Estimator**
|
||||
|
||||
|
|
@ -224,7 +278,7 @@ result_df = est.predict(df)
|
|||
|
||||
The input to `fit` and `predict` methods can be a *Spark Dataframe*, or an *XShards*. See the *data-parallel processing pipeline* [page](./data-parallel-processing.html) for more details.
|
||||
|
||||
View the related [Python API doc]() for more details.
|
||||
View the related [Python API doc](https://bigdl.readthedocs.io/en/latest/doc/PythonAPI/Orca/orca.html#module-bigdl.orca.learn.bigdl.estimator) for more details.
|
||||
|
||||
### **6. OpenVINO Estimator**
|
||||
|
||||
|
|
@ -249,4 +303,4 @@ result_shards = est.predict(shards)
|
|||
|
||||
The input to `predict` methods can be an *XShards*, or a *numpy array*. See the *data-parallel processing pipeline* [page](./data-parallel-processing.html) for more details.
|
||||
|
||||
View the related [Python API doc]() for more details.
|
||||
View the related [Python API doc](https://bigdl.readthedocs.io/en/latest/doc/PythonAPI/Orca/orca.html#orca-learn-openvino-estimator) for more details.
|
||||
|
|
|
|||
Loading…
Reference in a new issue