From 23aa10345f95a64e4d0c92ad558e2273bf065ae7 Mon Sep 17 00:00:00 2001 From: SONG Ge <38711238+sgwhat@users.noreply.github.com> Date: Thu, 31 Mar 2022 09:49:56 +0800 Subject: [PATCH] [Orca] Update Documents for Tf2estimator on Pyspark Backend (#4308) * update tf2estimator on pyspark backend docs --- .../distributed-training-inference.md | 74 ++++++++++++++++--- 1 file changed, 64 insertions(+), 10 deletions(-) diff --git a/docs/readthedocs/source/doc/Orca/Overview/distributed-training-inference.md b/docs/readthedocs/source/doc/Orca/Overview/distributed-training-inference.md index c8b53084..299afc79 100644 --- a/docs/readthedocs/source/doc/Orca/Overview/distributed-training-inference.md +++ b/docs/readthedocs/source/doc/Orca/Overview/distributed-training-inference.md @@ -60,11 +60,14 @@ predictions = est.predict(data=df, ``` The `data` argument in `fit` method can be a Spark DataFrame, an *XShards* or a `tf.data.Dataset`. The `data` argument in `predict` method can be a spark DataFrame or an *XShards*. See the *data-parallel processing pipeline* [page](./data-parallel-processing.md) for more details. -View the related [Python API doc]() for more details. +View the related [Python API doc](https://bigdl.readthedocs.io/en/latest/doc/PythonAPI/Orca/orca.html#module-bigdl.orca.learn.tf.estimator) for more details. #### **2.2 TensorFlow 2.x and Keras 2.4+** -Users can create an `Estimator` for TensorFlow 2.x from a Keras model (using a _Model Creator Function_). For example: +**Using `tf2` or *Horovod* backend** + +Users can create an `Estimator` for TensorFlow 2.x from a Keras model (using a _Model Creator Function_) when the backend is +`tf2` (currently default for TF2) or *Horovod*. For example: ```python def model_creator(config): @@ -73,7 +76,7 @@ def model_creator(config): loss='sparse_categorical_crossentropy', metrics=['accuracy']) return model -est = Estimator.from_keras(model_creator=model_creator) +est = Estimator.from_keras(model_creator=model_creator) # or backend="horovod" ``` The `model_creator` argument should be a function that takes a `config` dictionary and returns a compiled Keras model. @@ -95,9 +98,60 @@ predictions = est.predict(data=df, The `data` argument in `fit` method can be a spark DataFrame, an *XShards* or a *Data Creator Function* (that returns a `tf.data.Dataset`). The `data` argument in `predict` method can be a spark DataFrame or an *XShards*. See the *data-parallel processing pipeline* [page](./data-parallel-processing.md) for more details. -View the related [Python API doc]() for more details. +View the related [Python API doc](https://bigdl.readthedocs.io/en/latest/doc/PythonAPI/Orca/orca.html#orca-learn-tf2-tf2-ray-estimator) for more details. -***For more details, view the distributed TensorFlow training/inference [page]().*** +**Using *spark* backend** + +Users can create an `Estimator` for TensorFlow 2.x using the *spark* backend as follows: + +```python +def model_creator(config): + model = create_keras_lenet_model() + model.compile(**compile_args(config)) + return model + +def compile_args(config): + if "lr" in config: + lr = config["lr"] + else: + lr = 1e-2 + args = { + "optimizer": keras.optimizers.SGD(lr), + "loss": "mean_squared_error", + "metrics": ["mean_squared_error"] + } + return args + +est = Estimator.from_keras(model_creator=model_creator, + config={"lr": 1e-2}, + workers_per_node=2, + backend="spark", + model_dir=model_dir) +``` + +The `model_creator` argument should be a function that takes a `config` dictionary and returns a compiled Keras model. +The `model_dir` argument is required for *spark* backend, it should be a share filesystem path which can be accessed by executors for culster mode. + +Then users can perform distributed model training and inference as follows: + +```python +def train_data_creator(config, batch_size): + dataset = tfds.load(name="mnist", split="train") + dataset = dataset.map(preprocess) + dataset = dataset.batch(batch_size) + return dataset +stats = est.fit(data=train_data_creator, + epochs=max_epoch, + steps_per_epoch=total_size // batch_size) +predictions = est.predict(data=df, + feature_cols=['image']).collect() +``` + +The `data` argument in `fit` method can be a spark DataFrame, an *XShards* or a *Data Creator Function* (that returns a `tf.data.Dataset`). The `data` argument in `predict` method can be a spark DataFrame or an *XShards*. See the *data-parallel processing pipeline* [page](./data-parallel-processing.md) for more details. + +View the related [Python API doc](https://bigdl.readthedocs.io/en/latest/doc/PythonAPI/Orca/orca.html#orca-learn-tf2-tf2-spark-estimator) for more details. + +***For more details, view the distributed TensorFlow training/inference [page]().*** ### **3. PyTorch Estimator** @@ -123,7 +177,7 @@ predictions = est.predict(xshards) The input to `fit` methods can be a `torch.utils.data.DataLoader`, a Spark Dataframe, an *XShards*, or a *Data Creator Function* (that returns a `torch.utils.data.DataLoader`). The input to `predict` methods should be a Spark Dataframe, or an *XShards*. See the *data-parallel processing pipeline* [page](./data-parallel-processing.md) for more details. -View the related [Python API doc]() for more details. +View the related [Python API doc](https://bigdl.readthedocs.io/en/latest/doc/PythonAPI/Orca/orca.html#orca-learn-pytorch-pytorch-spark-estimator) for more details. **Using `torch.distributed` or *Horovod* backend** @@ -155,7 +209,7 @@ predictions = est.predict(data=df, The input to `fit` methods can be a Spark DataFrame, an *XShards*, or a *Data Creator Function* (that returns a `torch.utils.data.DataLoader`). The `data` argument in `predict` method can be a Spark DataFrame or an *XShards*. See the *data-parallel processing pipeline* [page](./data-parallel-processing.md) for more details. -View the related [Python API doc]() for more details. +View the related [Python API doc](https://bigdl.readthedocs.io/en/latest/doc/PythonAPI/Orca/orca.html#orca-learn-pytorch-pytorch-ray-estimator) for more details. ***For more details, view the distributed PyTorch training/inference [page]().*** @@ -194,7 +248,7 @@ est.fit(get_train_data_iter, epochs=2) The input to `fit` methods can be an *XShards*, or a *Data Creator Function* (that returns an `MXNet DataIter/DataLoader`). See the *data-parallel processing pipeline* [page](./data-parallel-processing.html) for more details. -View the related [Python API doc]() for more details. +View the related [Python API doc]() for more details. ### **5. BigDL Estimator** @@ -224,7 +278,7 @@ result_df = est.predict(df) The input to `fit` and `predict` methods can be a *Spark Dataframe*, or an *XShards*. See the *data-parallel processing pipeline* [page](./data-parallel-processing.html) for more details. -View the related [Python API doc]() for more details. +View the related [Python API doc](https://bigdl.readthedocs.io/en/latest/doc/PythonAPI/Orca/orca.html#module-bigdl.orca.learn.bigdl.estimator) for more details. ### **6. OpenVINO Estimator** @@ -249,4 +303,4 @@ result_shards = est.predict(shards) The input to `predict` methods can be an *XShards*, or a *numpy array*. See the *data-parallel processing pipeline* [page](./data-parallel-processing.html) for more details. -View the related [Python API doc]() for more details. +View the related [Python API doc](https://bigdl.readthedocs.io/en/latest/doc/PythonAPI/Orca/orca.html#orca-learn-openvino-estimator) for more details.