diff --git a/docs/readthedocs/source/doc/Orca/QuickStart/orca-pytorch-quickstart.md b/docs/readthedocs/source/doc/Orca/QuickStart/orca-pytorch-quickstart.md index d2886431..700f27fe 100644 --- a/docs/readthedocs/source/doc/Orca/QuickStart/orca-pytorch-quickstart.md +++ b/docs/readthedocs/source/doc/Orca/QuickStart/orca-pytorch-quickstart.md @@ -26,6 +26,7 @@ pip install jep==3.9.0 ```python from bigdl.orca import init_orca_context, stop_orca_context +cluster_mode = "local" if cluster_mode == "local": # For local machine init_orca_context(cores=4, memory="10g") elif cluster_mode == "k8s": # For K8s cluster @@ -59,7 +60,7 @@ class LeNet(nn.Module): self.conv2 = nn.Conv2d(20, 50, 5, 1) self.fc1 = nn.Linear(4*4*50, 500) self.fc2 = nn.Linear(500, 10) - + def forward(self, x): x = F.relu(self.conv1(x)) x = F.max_pool2d(x, 2, 2) @@ -69,7 +70,7 @@ class LeNet(nn.Module): x = F.relu(self.fc1(x)) x = self.fc2(x) return F.log_softmax(x, dim=1) - + model = LeNet() model.train() criterion = nn.NLLLoss()