Nano: TF multi process how-to for customized training loop (#8006)
* add how-to guide
* add overview
* fix doc
* fix pep8
* update the notebook
parent 9695ef2978
commit 2daaa6f7de

5 changed files with 57 additions and 4 deletions
@@ -137,6 +137,7 @@ subtrees:
 - file: doc/Nano/Howto/Training/TensorFlow/accelerate_tensorflow_training_multi_instance
 - file: doc/Nano/Howto/Training/TensorFlow/tensorflow_training_embedding_sparseadam
 - file: doc/Nano/Howto/Training/TensorFlow/tensorflow_training_bf16
+- file: doc/Nano/Howto/Training/TensorFlow/tensorflow_custom_training_multi_instance
 - file: doc/Nano/Howto/Training/General/index
   title: "General"
   subtrees:

@@ -4,6 +4,7 @@ Training Optimization: For TensorFlow Users
 * `How to accelerate a TensorFlow Keras application on training workloads through multiple instances <accelerate_tensorflow_training_multi_instance.html>`_
 * |tensorflow_training_embedding_sparseadam_link|_
 * `How to conduct BFloat16 Mixed Precision training in your TensorFlow application <tensorflow_training_bf16.html>`_
+* `How to accelerate TensorFlow Keras customized training loop through multiple instances <tensorflow_custom_training_multi_instance.html>`_

 .. |tensorflow_training_embedding_sparseadam_link| replace:: How to optimize your model with a sparse ``Embedding`` layer and ``SparseAdam`` optimizer
 .. _tensorflow_training_embedding_sparseadam_link: tensorflow_training_embedding_sparseadam.html

@@ -0,0 +1,3 @@
+{
+    "path": "../../../../../../../../python/nano/tutorial/notebook/training/tensorflow/tensorflow_custom_training_multi_instance.ipynb"
+}

@@ -42,6 +42,7 @@ TensorFlow
 * `How to accelerate a TensorFlow Keras application on training workloads through multiple instances <Training/TensorFlow/accelerate_tensorflow_training_multi_instance.html>`_
 * |tensorflow_training_embedding_sparseadam_link|_
 * `How to conduct BFloat16 Mixed Precision training in your TensorFlow Keras application <Training/TensorFlow/tensorflow_training_bf16.html>`_
+* `How to accelerate TensorFlow Keras customized training loop through multiple instances <Training/TensorFlow/tensorflow_custom_training_multi_instance.html>`_

 .. |tensorflow_training_embedding_sparseadam_link| replace:: How to optimize your model with a sparse ``Embedding`` layer and ``SparseAdam`` optimizer
 .. _tensorflow_training_embedding_sparseadam_link: Training/TensorFlow/tensorflow_training_embedding_sparseadam.html

@@ -1,17 +1,25 @@
 # TensorFlow Training

-BigDL-Nano can be used to accelerate TensorFlow Keras applications on training workloads. The optimizations in BigDL-Nano are delivered through BigDL-Nano's `Model` and `Sequential` classes, which have identical APIs with `tf.keras.Model` and `tf.keras.Sequential`. For most cases, you can just replace your `tf.keras.Model` with `bigdl.nano.tf.keras.Model` and `tf.keras.Sequential` with `bigdl.nano.tf.keras.Sequential` to benefit from BigDL-Nano.
-We will briefly describe here the major features in BigDL-Nano for TensorFlow training. You can find complete examples here [links to be added]().
+BigDL-Nano can be used to accelerate TensorFlow Keras applications on training workloads. The optimizations in BigDL-Nano are delivered through:
+
+- BigDL-Nano's `Model` and `Sequential` classes, which have APIs identical to `tf.keras.Model` and `tf.keras.Sequential` but with an enhanced `fit` method;
+- BigDL-Nano's `nano` decorator (potentially with the help of `nano_multiprocessing` and `nano_multiprocessing_loss`), which handles Keras models with customized training loops.
+
+We will briefly describe here the major features in BigDL-Nano for TensorFlow training.

 ### Best Known Configurations

 When you install BigDL-Nano with `pip install bigdl-nano[tensorflow]`, `intel-tensorflow` is installed in your environment; it has Intel's oneDNN optimizations enabled by default. When you run `source bigdl-nano-init`, it exports a few environment variables, such as `OMP_NUM_THREADS` and `KMP_AFFINITY`, according to your current hardware. Empirically, these environment variables work best for most TensorFlow applications. After setting them, you can run your applications as usual (`python app.py`) and no additional changes are required.

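As a minimal sketch for checking that setup (using only the environment variables named above; not part of this commit), you can print the values visible to the Python process:

```python
import os

# OMP_NUM_THREADS and KMP_AFFINITY are exported by `source bigdl-nano-init`;
# they are read when the OpenMP runtime initializes, so they must already be
# set when the TensorFlow process starts.
for var in ("OMP_NUM_THREADS", "KMP_AFFINITY"):
    print(var, "=", os.environ.get(var, "<not set>"))
```
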
 ### Multi-Instance Training

-When training on a server with dozens of CPU cores, it is often beneficial to use multiple training instances in a data-parallel fashion to make full use of the CPU cores. However, naively using TensorFlow's `MultiWorkerMirroredStrategy` can cause conflict in CPU cores and often cannot provide performance benefits.
+When training on a server with dozens of CPU cores, it is often beneficial to use multiple training instances in a data-parallel fashion to make full use of the CPU cores. However:
+
+- Naively using TensorFlow's `MultiWorkerMirroredStrategy` can cause conflicts over CPU cores and often brings no performance benefit.
+- A customized training loop can be hard to use together with `MultiWorkerMirroredStrategy`.

-BigDL-Nano makes it very easy to conduct multi-instance training correctly. You can just set the `num_processes` parameter in the `fit` method in your `Model` or `Sequential` object and BigDL-Nano will launch the specific number of processes to perform data-parallel training. Each process will be automatically pinned to a different subset of CPU cores to avoid conflict and maximize training throughput.
+BigDL-Nano makes it very easy to conduct multi-instance training correctly for models with either the default or a customized training loop.
+
+#### Keras Model with default training loop
+
+You can just set the `num_processes` parameter in the `fit` method of your `Model` or `Sequential` object, and BigDL-Nano will launch the specified number of processes to perform data-parallel training. Each process is automatically pinned to a different subset of CPU cores to avoid conflicts and maximize training throughput.

 ```python
 import tensorflow as tf

@@ -38,6 +46,45 @@ model.compile(optimizer='adam',
 model.fit(train_ds, epochs=3, validation_data=val_ds, num_processes=2)
 ```

+#### Keras Model with customized training loop
+
+To run a customized training loop in multiple processes, you only need to add two lines of code:
+
+- add `nano_multiprocessing` to the `train_step` function that calculates and applies the gradients;
+- add `@nano(num_processes=...)` to the training loop function that iterates over the full dataset.
+
+```python
+from bigdl.nano.tf.keras import nano_multiprocessing, nano
+import tensorflow as tf
+
+tf.random.set_seed(0)
+global_batch_size = 32
+
+model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(1,))])
+optimizer = tf.keras.optimizers.SGD()
+loss_object = tf.keras.losses.BinaryCrossentropy(from_logits=True)
+
+dataset = tf.data.Dataset.from_tensors(([1.], [1.])).repeat(128).batch(global_batch_size)
+
+@nano_multiprocessing  # <-- Just remove this line to run on 1 process
+@tf.function
+def train_step(inputs, model, loss_object, optimizer):
+    features, labels = inputs
+    with tf.GradientTape() as tape:
+        predictions = model(features, training=True)
+        loss = loss_object(labels, predictions)
+    gradients = tape.gradient(loss, model.trainable_variables)
+    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
+    return loss
+
+@nano(num_processes=4)  # <-- Just remove this line to run on 1 process
+def train_whole_data(model, dataset, loss_object, optimizer, train_step):
+    for inputs in dataset:
+        print(train_step(inputs, model, loss_object, optimizer))
+```

 Note that, different from the convention in [BigDL-Nano PyTorch multi-instance training](./pytorch_train.html#multi-instance-training), the effective batch size does not change in TensorFlow multi-instance training: it is still the batch size you specify in your dataset. This is because TensorFlow's `MultiWorkerMirroredStrategy` tries to split each batch into multiple sub-batches for the different workers. We chose this behavior to match the semantics of TensorFlow distributed training.

 When you do want to increase your effective `batch_size`, you can do so by changing it directly in your dataset definition; you may also want to scale the learning rate linearly with the `batch_size` (ramping it up gradually), as described in this [paper](https://arxiv.org/abs/1706.02677) published by Facebook.
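As a rough sketch of that linear scaling rule (the base values below are placeholders for illustration, not defaults from BigDL-Nano or this commit), the adjustment could look like:

```python
import tensorflow as tf

# Placeholder values: the batch size and learning rate known to work well
# for single-instance training.
base_batch_size = 32
base_lr = 0.01

# Increase the effective batch size directly in the dataset definition ...
new_batch_size = 128
dataset = tf.data.Dataset.from_tensors(([1.], [1.])).repeat(1024).batch(new_batch_size)

# ... and scale the learning rate linearly with the batch size
# (ideally warming up to this value over the first few epochs).
scaled_lr = base_lr * new_batch_size / base_batch_size
optimizer = tf.keras.optimizers.SGD(learning_rate=scaled_lr)
```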