Orca: update doc for pytorch estimator backend. (#6723)
* feat: update doc for pytorch estimator backend. * fix: remove ray global dependency. * rm: remove .swp file. * fix: revert ray import fix. * fix: replace model and optimizer with model_creator and optimizer_creator. * fix: delete unnecessary links. * fix: update index.md * fix: fix code style of quickstart and jupyter notebook. * fix: remove criterion. * fix: fix dataset description. * fix: fix code style. * fix: fix code style. * fix: update batch size and link * fix: update link * fix: fix code style. * fix: fix unnecessary code. * fix: fix typo. * fix: use relative path. * fix: fix typo. * fix: fix link.
This commit is contained in:
		
							parent
							
								
									ca3b088522
								
							
						
					
					
						commit
						745aaef5df
					
				
					 5 changed files with 208 additions and 67 deletions
				
			
		| 
						 | 
					@ -151,21 +151,25 @@ The `data` argument in `fit` method can be a spark DataFrame, an *XShards* or a
 | 
				
			||||||
 | 
					
 | 
				
			||||||
View the related [Python API doc](https://bigdl.readthedocs.io/en/latest/doc/PythonAPI/Orca/orca.html#orca-learn-tf2-tf2-spark-estimator) for more details.
 | 
					View the related [Python API doc](https://bigdl.readthedocs.io/en/latest/doc/PythonAPI/Orca/orca.html#orca-learn-tf2-tf2-spark-estimator) for more details.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
***For more details, view the distributed TensorFlow training/inference [page]()<TODO: link to be added>.***
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
### 3. PyTorch Estimator
 | 
					### 3. PyTorch Estimator
 | 
				
			||||||
 | 
					
 | 
				
			||||||
**Using *BigDL* backend**
 | 
					**Using *BigDL* backend**
 | 
				
			||||||
 | 
					
 | 
				
			||||||
Users may create a PyTorch `Estimator` using the *BigDL* backend (currently default for PyTorch) as follows:
 | 
					Users may create a PyTorch `Estimator` using the *Spark* backend (currently default for PyTorch) as follows:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
```python
 | 
					```python
 | 
				
			||||||
 | 
					def model_creator(config):
 | 
				
			||||||
    model = LeNet() # a torch.nn.Module
 | 
					    model = LeNet() # a torch.nn.Module
 | 
				
			||||||
    model.train()
 | 
					    model.train()
 | 
				
			||||||
criterion = nn.NLLLoss()
 | 
					    return model
 | 
				
			||||||
 | 
					
 | 
				
			||||||
adam = torch.optim.Adam(model.parameters(), args.lr)
 | 
					def optimizer_creator(model, config):
 | 
				
			||||||
est = Estimator.from_torch(model=model, optimizer=adam, loss=criterion)
 | 
					    return torch.optim.Adam(model.parameters(), config["lr"])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					est = Estimator.from_torch(model=model_creator,
 | 
				
			||||||
 | 
					                           optimizer=optimizer_creator,
 | 
				
			||||||
 | 
					                           loss=nn.NLLLoss(),
 | 
				
			||||||
 | 
					                           config={"lr": 1e-2})
 | 
				
			||||||
```
 | 
					```
 | 
				
			||||||
 | 
					
 | 
				
			||||||
Then users can perform distributed model training and inference as follows:
 | 
					Then users can perform distributed model training and inference as follows:
 | 
				
			||||||
| 
						 | 
					@ -192,7 +196,7 @@ def model_creator(config):
 | 
				
			||||||
def optimizer_creator(model, config):
 | 
					def optimizer_creator(model, config):
 | 
				
			||||||
    return torch.optim.Adam(model.parameters(), config["lr"])
 | 
					    return torch.optim.Adam(model.parameters(), config["lr"])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
est = Estimator.from_torch(model=model,
 | 
					est = Estimator.from_torch(model=model_creator,
 | 
				
			||||||
                           optimizer=optimizer_creator,
 | 
					                           optimizer=optimizer_creator,
 | 
				
			||||||
                           loss=nn.NLLLoss(),
 | 
					                           loss=nn.NLLLoss(),
 | 
				
			||||||
                           config={"lr": 1e-2},
 | 
					                           config={"lr": 1e-2},
 | 
				
			||||||
| 
						 | 
					@ -211,8 +215,6 @@ The input to `fit` methods can be a Spark DataFrame, an *XShards*, or a *Data Cr
 | 
				
			||||||
 | 
					
 | 
				
			||||||
View the related [Python API doc](https://bigdl.readthedocs.io/en/latest/doc/PythonAPI/Orca/orca.html#orca-learn-pytorch-pytorch-ray-estimator) for more details.
 | 
					View the related [Python API doc](https://bigdl.readthedocs.io/en/latest/doc/PythonAPI/Orca/orca.html#orca-learn-pytorch-pytorch-ray-estimator) for more details.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
***For more details, view the distributed PyTorch training/inference [page]()<TODO: link to be added>.***
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
### 4. MXNet Estimator
 | 
					### 4. MXNet Estimator
 | 
				
			||||||
 | 
					
 | 
				
			||||||
The user may create a MXNet `Estimator` as follows:
 | 
					The user may create a MXNet `Estimator` as follows:
 | 
				
			||||||
| 
						 | 
					@ -248,8 +250,6 @@ est.fit(get_train_data_iter, epochs=2)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
The input to `fit` methods can be an *XShards*, or a *Data Creator Function* (that returns an `MXNet DataIter/DataLoader`). See the *data-parallel processing pipeline* [page](./data-parallel-processing.html) for more details.
 | 
					The input to `fit` methods can be an *XShards*, or a *Data Creator Function* (that returns an `MXNet DataIter/DataLoader`). See the *data-parallel processing pipeline* [page](./data-parallel-processing.html) for more details.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
View the related [Python API doc]()<TODO: link to be added> for more details.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
### 5. BigDL Estimator
 | 
					### 5. BigDL Estimator
 | 
				
			||||||
 | 
					
 | 
				
			||||||
The user may create a BigDL `Estimator` as follows:
 | 
					The user may create a BigDL `Estimator` as follows:
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -11,7 +11,7 @@
 | 
				
			||||||
 | 
					
 | 
				
			||||||
- [**PyTorch Quickstart**](./orca-pytorch-quickstart.html)
 | 
					- [**PyTorch Quickstart**](./orca-pytorch-quickstart.html)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    > [Run in Google Colab](https://colab.research.google.com/github/intel-analytics/BigDL/blob/main/python/orca/colab-notebook/quickstart/pytorch_lenet_mnist.ipynb)  [View source on GitHub](https://github.com/intel-analytics/BigDL/blob/main/python/orca/colab-notebook/quickstart/pytorch_lenet_mnist.ipynb)
 | 
					    > [Run in Google Colab](https://colab.research.google.com/github/intel-analytics/BigDL/blob/main/python/orca/colab-notebook/quickstart/pytorch_lenet_mnist.ipynb)  [View source on GitHub](https://github.com/intel-analytics/BigDL/blob/main/python/orca/colab-notebook/quickstart/pytorch_lenet_mnist_spark.ipynb)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    In this guide we will describe how to scale out PyTorch programs using Orca in 5 simple steps.
 | 
					    In this guide we will describe how to scale out PyTorch programs using Orca in 5 simple steps.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -0,0 +1,149 @@
 | 
				
			||||||
 | 
					# PyTorch Quickstart
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					---
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					[Run in Google Colab](https://colab.research.google.com/github/intel-analytics/BigDL/blob/main/python/orca/colab-notebook/quickstart/pytorch_lenet_mnist_bigdl.ipynb)  [View source on GitHub](https://github.com/intel-analytics/BigDL/blob/main/python/orca/colab-notebook/quickstart/pytorch_lenet_mnist_bigdl.ipynb)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					---
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					**In this guide we will describe how to scale out _PyTorch_ programs using Orca in 4 simple steps.**
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					### Step 0: Prepare Environment
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					[Conda](https://docs.conda.io/projects/conda/en/latest/user-guide/install/) is needed to prepare the Python environment for running this example. Please refer to the [install guide](../../UserGuide/python.md) for more details.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					```bash
 | 
				
			||||||
 | 
					conda create -n py37 python=3.7  # "py37" is conda environment name, you can use any name you like.
 | 
				
			||||||
 | 
					conda activate py37
 | 
				
			||||||
 | 
					pip install bigdl-orca
 | 
				
			||||||
 | 
					pip install torch==1.7.1 torchvision==0.8.2
 | 
				
			||||||
 | 
					pip install six cloudpickle
 | 
				
			||||||
 | 
					pip install jep==3.9.0
 | 
				
			||||||
 | 
					```
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					### Step 1: Init Orca Context
 | 
				
			||||||
 | 
					```python
 | 
				
			||||||
 | 
					from bigdl.orca import init_orca_context, stop_orca_context
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					cluster_mode = "local"
 | 
				
			||||||
 | 
					if cluster_mode == "local":  # For local machine
 | 
				
			||||||
 | 
					    init_orca_context(cores=4, memory="10g")
 | 
				
			||||||
 | 
					elif cluster_mode == "k8s":  # For K8s cluster
 | 
				
			||||||
 | 
					    init_orca_context(cluster_mode="k8s", num_nodes=2, cores=2, memory="10g", driver_memory="10g", driver_cores=1)
 | 
				
			||||||
 | 
					elif cluster_mode == "yarn":  # For Hadoop/YARN cluster
 | 
				
			||||||
 | 
					    init_orca_context(
 | 
				
			||||||
 | 
					    cluster_mode="yarn", cores=2, num_nodes=2, memory="10g",
 | 
				
			||||||
 | 
					    driver_memory="10g", driver_cores=1,
 | 
				
			||||||
 | 
					    conf={"spark.rpc.message.maxSize": "1024",
 | 
				
			||||||
 | 
					        "spark.task.maxFailures": "1",
 | 
				
			||||||
 | 
					        "spark.driver.extraJavaOptions": "-Dbigdl.failure.retryTimes=1"})
 | 
				
			||||||
 | 
					```
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					This is the only place where you need to specify local or distributed mode. View [Orca Context](./../Overview/orca-context.md) for more details.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					**Note:** You should `export HADOOP_CONF_DIR=/path/to/hadoop/conf/dir` when running on Hadoop YARN cluster. View [Hadoop User Guide](./../../UserGuide/hadoop.md) for more details.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					### Step 2: Define the Model
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					You may define your model, loss and optimizer in the same way as in any standard (single node) PyTorch program.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					```python
 | 
				
			||||||
 | 
					import torch
 | 
				
			||||||
 | 
					import torch.nn as nn
 | 
				
			||||||
 | 
					import torch.nn.functional as F
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class LeNet(nn.Module):
 | 
				
			||||||
 | 
					    def __init__(self):
 | 
				
			||||||
 | 
					        super(LeNet, self).__init__()
 | 
				
			||||||
 | 
					        self.conv1 = nn.Conv2d(1, 20, 5, 1)
 | 
				
			||||||
 | 
					        self.conv2 = nn.Conv2d(20, 50, 5, 1)
 | 
				
			||||||
 | 
					        self.fc1 = nn.Linear(4*4*50, 500)
 | 
				
			||||||
 | 
					        self.fc2 = nn.Linear(500, 10)
 | 
				
			||||||
 | 
					        
 | 
				
			||||||
 | 
					    def forward(self, x):
 | 
				
			||||||
 | 
					        x = F.relu(self.conv1(x))
 | 
				
			||||||
 | 
					        x = F.max_pool2d(x, 2, 2)
 | 
				
			||||||
 | 
					        x = F.relu(self.conv2(x))
 | 
				
			||||||
 | 
					        x = F.max_pool2d(x, 2, 2)
 | 
				
			||||||
 | 
					        x = x.view(-1, 4*4*50)
 | 
				
			||||||
 | 
					        x = F.relu(self.fc1(x))
 | 
				
			||||||
 | 
					        x = self.fc2(x)
 | 
				
			||||||
 | 
					        return F.log_softmax(x, dim=1)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					model = LeNet()
 | 
				
			||||||
 | 
					model.train()
 | 
				
			||||||
 | 
					criterion = nn.NLLLoss()
 | 
				
			||||||
 | 
					adam = torch.optim.Adam(model.parameters(), 0.001)
 | 
				
			||||||
 | 
					```
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					### Step 3: Define Train Dataset
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					You can define the dataset using standard [Pytorch DataLoader](https://pytorch.org/docs/stable/data.html). 
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					```python
 | 
				
			||||||
 | 
					import torch
 | 
				
			||||||
 | 
					from torchvision import datasets, transforms
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					torch.manual_seed(0)
 | 
				
			||||||
 | 
					dir='./'
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					batch_size=64
 | 
				
			||||||
 | 
					test_batch_size=64
 | 
				
			||||||
 | 
					train_loader = torch.utils.data.DataLoader(
 | 
				
			||||||
 | 
					    datasets.MNIST(dir, train=True, download=True,
 | 
				
			||||||
 | 
					                   transform=transforms.Compose([
 | 
				
			||||||
 | 
					                       transforms.ToTensor(),
 | 
				
			||||||
 | 
					                       transforms.Normalize((0.1307,), (0.3081,))
 | 
				
			||||||
 | 
					                   ])),
 | 
				
			||||||
 | 
					    batch_size=batch_size, shuffle=True)
 | 
				
			||||||
 | 
					test_loader = torch.utils.data.DataLoader(
 | 
				
			||||||
 | 
					    datasets.MNIST(dir, train=False,
 | 
				
			||||||
 | 
					                   transform=transforms.Compose([
 | 
				
			||||||
 | 
					                       transforms.ToTensor(),
 | 
				
			||||||
 | 
					                       transforms.Normalize((0.1307,), (0.3081,))
 | 
				
			||||||
 | 
					                   ])),
 | 
				
			||||||
 | 
					    batch_size=test_batch_size, shuffle=False)
 | 
				
			||||||
 | 
					```
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Alternatively, we can also use a [Data Creator Function](https://github.com/intel-analytics/BigDL/blob/main/docs/docs/colab-notebook/orca/quickstart/pytorch_lenet_mnist_data_creator_func.ipynb) or [Orca XShards](../Overview/data-parallel-processing) as the input data, especially when the data size is very large)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					### Step 4: Fit with Orca Estimator
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					First, Create an Estimator
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					```python
 | 
				
			||||||
 | 
					from bigdl.orca.learn.pytorch import Estimator 
 | 
				
			||||||
 | 
					from bigdl.orca.learn.metrics import Accuracy
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					est = Estimator.from_torch(model=model, optimizer=adam, loss=criterion, metrics=[Accuracy()])
 | 
				
			||||||
 | 
					```
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Next, fit and evaluate using the Estimator
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					```python
 | 
				
			||||||
 | 
					from bigdl.orca.learn.trigger import EveryEpoch 
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					est.fit(data=train_loader, epochs=10, validation_data=test_loader,
 | 
				
			||||||
 | 
					        checkpoint_trigger=EveryEpoch())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					result = est.evaluate(data=test_loader)
 | 
				
			||||||
 | 
					for r in result:
 | 
				
			||||||
 | 
					    print(r, ":", result[r])
 | 
				
			||||||
 | 
					```
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					### Step 5: Save and Load the Model
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Save the Estimator states (including model and optimizer) to the provided model path.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					```python
 | 
				
			||||||
 | 
					est.save("mnist_model")
 | 
				
			||||||
 | 
					```
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Load the Estimator states (model and possibly with optimizer) from the provided model path.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					```python
 | 
				
			||||||
 | 
					est.load("mnist_model")
 | 
				
			||||||
 | 
					```
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					**Note:** You should call `stop_orca_context()` when your application finishes.
 | 
				
			||||||
| 
						 | 
					@ -2,7 +2,7 @@
 | 
				
			||||||
 | 
					
 | 
				
			||||||
---
 | 
					---
 | 
				
			||||||
 | 
					
 | 
				
			||||||
[Run in Google Colab](https://colab.research.google.com/github/intel-analytics/BigDL/blob/main/python/orca/colab-notebook/quickstart/pytorch_distributed_lenet_mnist.ipynb)  [View source on GitHub](https://github.com/intel-analytics/BigDL/blob/main/python/orca/colab-notebook/quickstart/pytorch_distributed_lenet_mnist.ipynb)
 | 
					[Run in Google Colab](https://colab.research.google.com/github/intel-analytics/BigDL/blob/main/python/orca/colab-notebook/quickstart/pytorch_lenet_mnist_ray.ipynb)  [View source on GitHub](https://github.com/intel-analytics/BigDL/blob/main/python/orca/colab-notebook/quickstart/pytorch_lenet_mnist_ray.ipynb)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
---
 | 
					---
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -2,24 +2,22 @@
 | 
				
			||||||
 | 
					
 | 
				
			||||||
---
 | 
					---
 | 
				
			||||||
 | 
					
 | 
				
			||||||
[Run in Google Colab](https://colab.research.google.com/github/intel-analytics/BigDL/blob/main/python/orca/colab-notebook/quickstart/pytorch_lenet_mnist.ipynb)  [View source on GitHub](https://github.com/intel-analytics/BigDL/blob/main/python/orca/colab-notebook/quickstart/pytorch_lenet_mnist.ipynb)
 | 
					[Run in Google Colab](https://colab.research.google.com/github/intel-analytics/BigDL/blob/main/python/orca/colab-notebook/quickstart/pytorch_lenet_mnist_spark.ipynb)  [View source on GitHub](https://github.com/intel-analytics/BigDL/blob/main/python/orca/colab-notebook/quickstart/pytorch_lenet_mnist_spark.ipynb)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
---
 | 
					---
 | 
				
			||||||
 | 
					
 | 
				
			||||||
**In this guide we will describe how to scale out _PyTorch_ programs using Orca in 4 simple steps.**
 | 
					**In this guide we will describe how to scale out _PyTorch_ programs using Orca in 5 simple steps.**
 | 
				
			||||||
 | 
					
 | 
				
			||||||
### Step 0: Prepare Environment
 | 
					### Step 0: Prepare Environment
 | 
				
			||||||
 | 
					
 | 
				
			||||||
[Conda](https://docs.conda.io/projects/conda/en/latest/user-guide/install/) is needed to prepare the Python environment for running this example. Please refer to the [install guide](../../UserGuide/python.md) for more details.
 | 
					[Conda](https://docs.conda.io/projects/conda/en/latest/user-guide/install/) is needed to prepare the Python environment for running this example. Please refer to the [install guide](../../UserGuide/python.md) for more details.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					 | 
				
			||||||
```bash
 | 
					```bash
 | 
				
			||||||
conda create -n py37 python=3.7  # "py37" is conda environment name, you can use any name you like.
 | 
					conda create -n py37 python=3.7  # "py37" is conda environment name, you can use any name you like.
 | 
				
			||||||
conda activate py37
 | 
					conda activate py37
 | 
				
			||||||
pip install bigdl-orca
 | 
					pip install --pre --upgrade bigdl-orca 
 | 
				
			||||||
pip install torch==1.7.1 torchvision==0.8.2
 | 
					pip install torch torchvision
 | 
				
			||||||
pip install six cloudpickle
 | 
					pip install tqdm
 | 
				
			||||||
pip install jep==3.9.0
 | 
					 | 
				
			||||||
```
 | 
					```
 | 
				
			||||||
 | 
					
 | 
				
			||||||
### Step 1: Init Orca Context
 | 
					### Step 1: Init Orca Context
 | 
				
			||||||
| 
						 | 
					@ -32,12 +30,7 @@ if cluster_mode == "local":  # For local machine
 | 
				
			||||||
elif cluster_mode == "k8s":  # For K8s cluster
 | 
					elif cluster_mode == "k8s":  # For K8s cluster
 | 
				
			||||||
    init_orca_context(cluster_mode="k8s", num_nodes=2, cores=2, memory="10g", driver_memory="10g", driver_cores=1)
 | 
					    init_orca_context(cluster_mode="k8s", num_nodes=2, cores=2, memory="10g", driver_memory="10g", driver_cores=1)
 | 
				
			||||||
elif cluster_mode == "yarn":  # For Hadoop/YARN cluster
 | 
					elif cluster_mode == "yarn":  # For Hadoop/YARN cluster
 | 
				
			||||||
    init_orca_context(
 | 
					    init_orca_context(cluster_mode="yarn", num_nodes=2, cores=2, memory="10g", driver_memory="10g", driver_cores=1)
 | 
				
			||||||
    cluster_mode="yarn", cores=2, num_nodes=2, memory="10g",
 | 
					 | 
				
			||||||
    driver_memory="10g", driver_cores=1,
 | 
					 | 
				
			||||||
    conf={"spark.rpc.message.maxSize": "1024",
 | 
					 | 
				
			||||||
        "spark.task.maxFailures": "1",
 | 
					 | 
				
			||||||
        "spark.driver.extraJavaOptions": "-Dbigdl.failure.retryTimes=1"})
 | 
					 | 
				
			||||||
```
 | 
					```
 | 
				
			||||||
 | 
					
 | 
				
			||||||
This is the only place where you need to specify local or distributed mode. View [Orca Context](./../Overview/orca-context.md) for more details.
 | 
					This is the only place where you need to specify local or distributed mode. View [Orca Context](./../Overview/orca-context.md) for more details.
 | 
				
			||||||
| 
						 | 
					@ -70,26 +63,30 @@ class LeNet(nn.Module):
 | 
				
			||||||
        x = F.relu(self.fc1(x))
 | 
					        x = F.relu(self.fc1(x))
 | 
				
			||||||
        x = self.fc2(x)
 | 
					        x = self.fc2(x)
 | 
				
			||||||
        return F.log_softmax(x, dim=1)
 | 
					        return F.log_softmax(x, dim=1)
 | 
				
			||||||
 | 
					```
 | 
				
			||||||
 | 
					After defining your model, you need to define a *Model Creator Function* that takes the parameter `config` and returns an instance of your model, and a *Optimizer Creator Function* that has two parameters `model` and `config` and returns a PyTorch optimizer.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					```python
 | 
				
			||||||
 | 
					def model_creator(config):
 | 
				
			||||||
    model = LeNet()
 | 
					    model = LeNet()
 | 
				
			||||||
model.train()
 | 
					    return model
 | 
				
			||||||
criterion = nn.NLLLoss()
 | 
					
 | 
				
			||||||
adam = torch.optim.Adam(model.parameters(), 0.001)
 | 
					def optim_creator(model, config):
 | 
				
			||||||
 | 
					    return torch.optim.Adam(model.parameters(), lr=0.001)
 | 
				
			||||||
```
 | 
					```
 | 
				
			||||||
 | 
					
 | 
				
			||||||
### Step 3: Define Train Dataset
 | 
					### Step 3: Define Train Dataset
 | 
				
			||||||
 | 
					
 | 
				
			||||||
You can define the dataset using standard [Pytorch DataLoader](https://pytorch.org/docs/stable/data.html). 
 | 
					You can define the dataset using a *Data Creator Function* that has two parameters `config` and `batch_size` and returns a [Pytorch DataLoader](https://pytorch.org/docs/stable/data.html). Orca also supports [Spark DataFrames](../Overview/data-parallel-processing.html#spark-dataframes) and [XShards](../Overview/data-parallel-processing.html#xshards-distributed-data-parallel-python-processing).
 | 
				
			||||||
 | 
					
 | 
				
			||||||
```python
 | 
					```python
 | 
				
			||||||
import torch
 | 
					import torch
 | 
				
			||||||
from torchvision import datasets, transforms
 | 
					from torchvision import datasets, transforms
 | 
				
			||||||
 | 
					
 | 
				
			||||||
torch.manual_seed(0)
 | 
					 | 
				
			||||||
dir='./'
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
batch_size = 64
 | 
					batch_size = 64
 | 
				
			||||||
test_batch_size=64
 | 
					dir = '/tmp/dataset'
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def train_loader_creator(config, batch_size):
 | 
				
			||||||
    train_loader = torch.utils.data.DataLoader(
 | 
					    train_loader = torch.utils.data.DataLoader(
 | 
				
			||||||
        datasets.MNIST(dir, train=True, download=True,
 | 
					        datasets.MNIST(dir, train=True, download=True,
 | 
				
			||||||
                       transform=transforms.Compose([
 | 
					                       transform=transforms.Compose([
 | 
				
			||||||
| 
						 | 
					@ -97,17 +94,19 @@ train_loader = torch.utils.data.DataLoader(
 | 
				
			||||||
                           transforms.Normalize((0.1307,), (0.3081,))
 | 
					                           transforms.Normalize((0.1307,), (0.3081,))
 | 
				
			||||||
                       ])),
 | 
					                       ])),
 | 
				
			||||||
        batch_size=batch_size, shuffle=True)
 | 
					        batch_size=batch_size, shuffle=True)
 | 
				
			||||||
 | 
					    return train_loader
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def test_loader_creator(config, batch_size):
 | 
				
			||||||
    test_loader = torch.utils.data.DataLoader(
 | 
					    test_loader = torch.utils.data.DataLoader(
 | 
				
			||||||
        datasets.MNIST(dir, train=False,
 | 
					        datasets.MNIST(dir, train=False,
 | 
				
			||||||
                       transform=transforms.Compose([
 | 
					                       transform=transforms.Compose([
 | 
				
			||||||
                           transforms.ToTensor(),
 | 
					                           transforms.ToTensor(),
 | 
				
			||||||
                           transforms.Normalize((0.1307,), (0.3081,))
 | 
					                           transforms.Normalize((0.1307,), (0.3081,))
 | 
				
			||||||
                       ])),
 | 
					                       ])),
 | 
				
			||||||
    batch_size=test_batch_size, shuffle=False)
 | 
					        batch_size=batch_size, shuffle=False)
 | 
				
			||||||
 | 
					    return test_loader
 | 
				
			||||||
```
 | 
					```
 | 
				
			||||||
 | 
					
 | 
				
			||||||
Alternatively, we can also use a [Data Creator Function](https://github.com/intel-analytics/BigDL/blob/main/docs/docs/colab-notebook/orca/quickstart/pytorch_lenet_mnist_data_creator_func.ipynb) or [Orca XShards](../Overview/data-parallel-processing) as the input data, especially when the data size is very large)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
### Step 4: Fit with Orca Estimator
 | 
					### Step 4: Fit with Orca Estimator
 | 
				
			||||||
 | 
					
 | 
				
			||||||
First, Create an Estimator
 | 
					First, Create an Estimator
 | 
				
			||||||
| 
						 | 
					@ -116,34 +115,27 @@ First, Create an Estimator
 | 
				
			||||||
from bigdl.orca.learn.pytorch import Estimator 
 | 
					from bigdl.orca.learn.pytorch import Estimator 
 | 
				
			||||||
from bigdl.orca.learn.metrics import Accuracy
 | 
					from bigdl.orca.learn.metrics import Accuracy
 | 
				
			||||||
 | 
					
 | 
				
			||||||
est = Estimator.from_torch(model=model, optimizer=adam, loss=criterion, metrics=[Accuracy()], backend="bigdl")
 | 
					est = Estimator.from_torch(model=model_creator, optimizer=optim_creator, loss=nn.NLLLoss(), metrics=[Accuracy()], use_tqdm=True)
 | 
				
			||||||
```
 | 
					```
 | 
				
			||||||
 | 
					
 | 
				
			||||||
Next, fit and evaluate using the Estimator
 | 
					Next, fit and evaluate using the Estimator
 | 
				
			||||||
 | 
					
 | 
				
			||||||
```python
 | 
					```python
 | 
				
			||||||
from bigdl.orca.learn.trigger import EveryEpoch 
 | 
					est.fit(data=train_loader_creator, epochs=1, batch_size=batch_size)
 | 
				
			||||||
 | 
					result = est.evaluate(data=test_loader_creator, batch_size=batch_size)
 | 
				
			||||||
est.fit(data=train_loader, epochs=10, validation_data=test_loader,
 | 
					 | 
				
			||||||
        checkpoint_trigger=EveryEpoch())
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
result = est.evaluate(data=test_loader)
 | 
					 | 
				
			||||||
for r in result:
 | 
					for r in result:
 | 
				
			||||||
    print(r, ":", result[r])
 | 
					    print(r, ":", result[r])
 | 
				
			||||||
```
 | 
					```
 | 
				
			||||||
 | 
					
 | 
				
			||||||
### Step 5: Save and Load the Model
 | 
					### Step 5: Save the Model
 | 
				
			||||||
 | 
					
 | 
				
			||||||
Save the Estimator states (including model and optimizer) to the provided model path.
 | 
					Save the Estimator states (including model and optimizer) to the provided model path.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
```python
 | 
					```python
 | 
				
			||||||
est.save("mnist_model")
 | 
					est.save("mnist_model")
 | 
				
			||||||
```
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
Load the Estimator states (model and possibly with optimizer) from the provided model path.
 | 
					# stop orca context when program finishes
 | 
				
			||||||
 | 
					stop_orca_context()
 | 
				
			||||||
```python
 | 
					 | 
				
			||||||
est.load("mnist_model")
 | 
					 | 
				
			||||||
```
 | 
					```
 | 
				
			||||||
 | 
					
 | 
				
			||||||
**Note:** You should call `stop_orca_context()` when your application finishes.
 | 
					**Note:** You should call `stop_orca_context()` when your application finishes.
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in a new issue