sagemaker-python-sdk/tensorflow_distributed_mnist/tensorflow_distributed_mnist.ipynb
Lines changed: 70 additions & 23 deletions
@@ -4,7 +4,17 @@
  "cell_type": "markdown",
  "metadata": {},
  "source": [
-  "## Let's start by setting up the environment."
+  "# MNIST distributed training\n",
+  "\n",
+  "The **SageMaker Python SDK** helps you deploy your models for training and hosting in optimized, production-ready containers in SageMaker. The SageMaker Python SDK is easy to use, modular, extensible, and compatible with TensorFlow and MXNet. This tutorial focuses on how to create a convolutional neural network model to train the [MNIST dataset](http://yann.lecun.com/exdb/mnist/) using **TensorFlow distributed training**.\n",
+  "\n"
+  ]
+ },
+ {
+  "cell_type": "markdown",
+  "metadata": {},
+  "source": [
+  "### Set up the environment"
  ]
 },
 {
@@ -20,15 +30,14 @@
  "\n",
  "sagemaker_session = sagemaker.Session()\n",
  "\n",
-  "# Replace with a role (either name or full arn) that gives SageMaker access to S3 and cloudwatch\n",
-  "role='SageMakerRole'"
+  "role = get_execution_role()"
  ]
 },
 {
  "cell_type": "markdown",
  "metadata": {},
  "source": [
-  "## Downloading test and training data"
+  "### Download the MNIST dataset"
  ]
 },
 {
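The change above swaps a hardcoded role name for ```get_execution_role()```. Here is a hedged sketch of how that line is typically used, with an offline fallback; the placeholder ARN is an illustration, not something from the notebook:

```python
# Sketch: resolve the IAM role that the training job will run with.
# Inside a SageMaker notebook instance, get_execution_role() returns the
# notebook's own role; anywhere else it raises, so we fall back here.
try:
    from sagemaker import get_execution_role
    role = get_execution_role()
except Exception:
    # placeholder ARN -- replace with a role that can access S3 and CloudWatch
    role = "arn:aws:iam::111122223333:role/SageMakerRole"
```

Either way, `role` ends up as a string the estimator can accept.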
@@ -54,7 +63,8 @@
  "cell_type": "markdown",
  "metadata": {},
  "source": [
-  "## Uploading the data"
+  "### Upload the data\n",
+  "We use the ```sagemaker.Session.upload_data``` function to upload our datasets to an S3 location. The return value ```inputs``` identifies the location -- we will use this later when we start the training job."
  ]
 },
 {
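A hedged sketch of the upload call the text describes; the ```path``` and ```key_prefix``` values are illustrative, and the fallback keeps the sketch self-contained when no live SageMaker session exists:

```python
# Sketch: upload the local 'data' directory and capture the returned S3 location.
try:
    inputs = sagemaker_session.upload_data(path='data', key_prefix='data/DEMO-mnist')
except NameError:
    # no live session in this sketch; upload_data returns an s3:// URI like this
    inputs = 's3://<your-default-bucket>/data/DEMO-mnist'
```

Whatever the bucket, `inputs` is the `s3://` URI later passed to `fit()`.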
@@ -72,7 +82,8 @@
  "cell_type": "markdown",
  "metadata": {},
  "source": [
-  "# Complete source code"
+  "# Construct a script for distributed training\n",
+  "Here is the full code for the network model:"
  ]
 },
 {
@@ -90,12 +101,36 @@
  "cell_type": "markdown",
  "metadata": {},
  "source": [
-  "# Running TensorFlow training on SageMaker\n",
+  "The script here is an adaptation of the [TensorFlow MNIST example](https://github.com/tensorflow/models/tree/master/official/mnist). It provides a ```model_fn(features, labels, mode)```, which is used for training, evaluation and inference.\n",
+  "\n",
+  "## A regular ```model_fn```\n",
  "\n",
-  "We can use the SDK to run our local training script on SageMaker infrastructure.\n",
+  "A regular **```model_fn```** follows the pattern:\n",
+  "1. [defines a neural network](https://github.com/tensorflow/models/blob/master/official/mnist/mnist.py#L96)\n",
+  "2. [applies the ```features``` in the neural network](https://github.com/tensorflow/models/blob/master/official/mnist/mnist.py#L178)\n",
+  "3. [if the ```mode``` is ```PREDICT```, returns the output from the neural network](https://github.com/tensorflow/models/blob/master/official/mnist/mnist.py#L186)\n",
+  "4. [calculates the loss function comparing the output with the ```labels```](https://github.com/tensorflow/models/blob/master/official/mnist/mnist.py#L188)\n",
+  "5. [creates an optimizer and minimizes the loss function to improve the neural network](https://github.com/tensorflow/models/blob/master/official/mnist/mnist.py#L193)\n",
+  "6. [returns the output, optimizer and loss function](https://github.com/tensorflow/models/blob/master/official/mnist/mnist.py#L205)\n",
  "\n",
-  "1. Pass the path to the abalone.py file, which contains the functions for defining your estimator, to the sagemaker.TensorFlow init method.\n",
-  "2. Pass the S3 location that we uploaded our data to previously to the fit() method."
+  "## Writing a ```model_fn``` for distributed training\n",
+  "When distributed training happens, the same neural network is sent to multiple training instances. Each instance predicts a batch of the dataset, calculates the loss, and lets the optimizer minimize it. One entire loop of this process is called a **training step**.\n",
+  "\n",
+  "### Synchronizing training steps\n",
+  "A [global step](https://www.tensorflow.org/api_docs/python/tf/train/global_step) is a global variable shared between the instances. It is necessary for distributed training, so the optimizer can keep track of the number of **training steps** between runs:\n",
"The **```fit```** method will create a training job in two **ml.c4.xlarge** instances. The logs above will show the instances doing training, evaluation, and incrementing the number of **training steps**. \n",
161
+
"\n",
162
+
"In the end of the training, the training job will generate a saved model for TF serving."
"notice": "Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the \"License\"). You may not use this file except in compliance with the License. A copy of the License is located at http://aws.amazon.com/apache2.0/ or in the \"license\" file accompanying this file. This file is distributed on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.",
"notice": "Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the \"License\"). You may not use this file except in compliance with the License. A copy of the License is located at http://aws.amazon.com/apache2.0/ or in the \"license\" file accompanying this file. This file is distributed on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License."