Commit a10de71

revise mxnet_mnist example (#58)
1 parent 48bc02e commit a10de71

1 file changed

+95
-21
lines changed

sagemaker-python-sdk/mxnet_mnist/mxnet_mnist.ipynb

Lines changed: 95 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -4,43 +4,81 @@
44
"cell_type": "markdown",
55
"metadata": {},
66
"source": [
7-
"## Mxnet MNIST Single Machine SageMaker Training Example\n",
7+
"# Training and hosting SageMaker Models using the Apache MXNet Module API\n",
88
"\n",
9-
"MNIST is a widely used dataset for handwritten digit classification. It consists of 70,000 labeled 28x28 pixel grayscale images of hand-written digits. The dataset is split into 60,000 training images and 10,000 test images. There are 10 classes (one for each of the 10 digits). The task at hand is to train a model using the 60,000 training images and subsequently test its classification accuracy on the 10,000 test images.\n",
10-
"\n"
9+
"The **SageMaker Python SDK** makes it easy to train and deploy MXNet models. In this example, we train a simple neural network using the Apache MXNet [Module API](https://mxnet.incubator.apache.org/api/python/module.html) and the MNIST dataset. The MNIST dataset is widely used for handwritten digit classification, and consists of 70,000 labeled 28x28 pixel grayscale images of hand-written digits. The dataset is split into 60,000 training images and 10,000 test images. There are 10 classes (one for each of the 10 digits). The task at hand is to train a model using the 60,000 training images and subsequently test its classification accuracy on the 10,000 test images.\n",
10+
"\n",
11+
"### Setup\n",
12+
"\n",
13+
"First we need to define a few variables that will be needed later in the example."
1114
]
1215
},
1316
{
1417
"cell_type": "code",
1518
"execution_count": null,
1619
"metadata": {
20+
"collapsed": true,
1721
"isConfigCell": true
1822
},
1923
"outputs": [],
2024
"source": [
2125
"from sagemaker import get_execution_role\n",
2226
"\n",
2327
"#Bucket location to save your custom code in tar.gz format.\n",
24-
"custom_code_upload_location = 's3://<bucket-name>/customcode/mxnet_mnist'\n",
28+
"custom_code_upload_location = 's3://<bucket-name>/customcode/mxnet_mnist'\n",
29+
"\n",
2530
"#Bucket location where results of model training are saved.\n",
2631
"model_artifacts_location = 's3://<bucket-name>/artifacts'\n",
2732
"\n",
33+
"#IAM execution role that gives SageMaker access to resources in your AWS account.\n",
34+
"#We can use the SageMaker Python SDK to get the role from our notebook environment. \n",
2835
"role = get_execution_role()"
2936
]
3037
},
3138
{
3239
"cell_type": "markdown",
3340
"metadata": {},
3441
"source": [
35-
"The ```MXNet``` class allows us to run single machine, multi-machine, and GPU mxnet training on SageMaker. Below we create an MXNet object to run our mnist training, passing in an IAMRole name to allow SageMaker to access our AWS resources. We run SageMaker mxnet training on a single ```m4.xlarge```.\n",
42+
"### The training script\n",
3643
"\n",
37-
"Please see the ```mnist.py``` script to learn more about how training is performed. The script is an adaptation of the mxnet MNIST tutorial, found here: https://mxnet.incubator.apache.org/tutorials/python/mnist.html"
44+
"The ``mnist.py`` script provides all the code we need for training and hosting a SageMaker model. The script we will use is adapted from the Apache MXNet [MNIST tutorial](https://mxnet.incubator.apache.org/tutorials/python/mnist.html)."
3845
]
3946
},
4047
{
4148
"cell_type": "code",
4249
"execution_count": null,
50+
"metadata": {
51+
"collapsed": true
52+
},
53+
"outputs": [],
54+
"source": [
55+
"!cat mnist.py"
56+
]
57+
},
58+
{
59+
"cell_type": "markdown",
60+
"metadata": {},
61+
"source": [
62+
"### SageMaker's MXNet estimator class"
63+
]
64+
},
65+
{
66+
"cell_type": "markdown",
4367
"metadata": {},
68+
"source": [
69+
"The SageMaker ```MXNet``` estimator allows us to run single machine or distributed training in SageMaker, using CPU or GPU-based instances.\n",
70+
"\n",
71+
"When we create the estimator, we pass in the filename of our training script, the name of our IAM execution role, and the S3 locations we defined in the setup section. We also provide a few other parameters. ``train_instance_count`` and ``train_instance_type`` determine the number and type of SageMaker instances that will be used for the training job. The ``hyperparameters`` parameter is a ``dict`` of values that will be passed to your training script -- you can see how to access these values in the ``mnist.py`` script above.\n",
72+
"\n",
73+
"For this example, we will choose one ``ml.m4.xlarge`` instance."
74+
]
75+
},
76+
{
77+
"cell_type": "code",
78+
"execution_count": null,
79+
"metadata": {
80+
"collapsed": true
81+
},
4482
"outputs": [],
4583
"source": [
4684
"from sagemaker.mxnet import MXNet\n",
@@ -51,24 +89,30 @@
5189
" code_location=custom_code_upload_location,\n",
5290
" train_instance_count=1, \n",
5391
" train_instance_type='ml.m4.xlarge',\n",
54-
" hyperparameters={'learning_rate': 0.11})"
92+
" hyperparameters={'learning_rate': 0.1})"
5593
]
5694
},
5795
{
5896
"cell_type": "markdown",
5997
"metadata": {},
6098
"source": [
61-
"- TODO: Make the ECR images this is using public\n",
62-
"\n",
63-
"After we've constructed our MXNet object, we can fit it using data stored in S3. Below we run SageMaker training on two input channels: train and test.\n",
99+
"### Running the Training Job"
100+
]
101+
},
102+
{
103+
"cell_type": "markdown",
104+
"metadata": {},
105+
"source": [
106+
"After we've constructed our MXNet object, we can fit it using data stored in S3. Below we run SageMaker training on two input channels: **train** and **test**.\n",
64107
"\n",
65-
"During training, SageMaker makes this data stored in S3 available in the local filesystem where the mnist script is running. The ```mnist.py``` script simply loads the train and test data from disk.\n"
108+
"During training, SageMaker makes this data stored in S3 available in the local filesystem where the mnist script is running. The ```mnist.py``` script simply loads the train and test data from disk."
66109
]
67110
},
68111
{
69112
"cell_type": "code",
70113
"execution_count": null,
71114
"metadata": {
115+
"collapsed": true,
72116
"scrolled": true
73117
},
74118
"outputs": [],
@@ -87,15 +131,19 @@
87131
"cell_type": "markdown",
88132
"metadata": {},
89133
"source": [
90-
"After training, we use the MXNet object to build and deploy an MXNetPredictor object. This creates an sagemaker-hosted prediction service that we can use to perform inference. \n",
134+
"### Creating an inference Endpoint\n",
135+
"\n",
136+
"After training, we use the ``MXNet`` estimator object to build and deploy an ``MXNetPredictor``. This creates a SageMaker **Endpoint** -- a hosted prediction service that we can use to perform inference.\n",
91137
"\n",
92-
"This allows us to perform inference on json encoded multi-dimensional arrays. "
138+
"The arguments to the ``deploy`` function allow us to set the number and type of instances that will be used for the Endpoint. These do not need to be the same as the values we used for the training job. For example, you can train a model on a set of GPU-based instances, and then deploy the Endpoint to a fleet of CPU-based instances. Here we will deploy the model to a single ``ml.c4.xlarge`` instance."
93139
]
94140
},
95141
{
96142
"cell_type": "code",
97143
"execution_count": null,
98-
"metadata": {},
144+
"metadata": {
145+
"collapsed": true
146+
},
99147
"outputs": [],
100148
"source": [
101149
"%%time\n",
@@ -108,23 +156,41 @@
108156
"cell_type": "markdown",
109157
"metadata": {},
110158
"source": [
111-
"We can now use this predictor to classify hand-written digits. Drawing into the image box loads the pixel data into a 'data' variable in this notebook, which we can then pass to the mxnet predictor. "
159+
"The request handling behavior of the Endpoint is determined by the ``mnist.py`` script. In this case, the script doesn't include any request handling functions, so the Endpoint will use the default handlers provided by SageMaker. These default handlers allow us to perform inference on input data encoded as a multi-dimensional JSON array.\n",
160+
"\n",
161+
"### Making an inference request\n",
162+
"\n",
163+
"Now that our Endpoint is deployed and we have a ``predictor`` object, we can use it to classify handwritten digits.\n",
164+
"\n",
165+
"To see inference in action, draw a digit in the image box below. The pixel data from your drawing will be loaded into a ``data`` variable in this notebook. \n",
166+
"\n",
167+
"*Note: after drawing the image, you'll need to move to the next notebook cell.*"
112168
]
113169
},
114170
{
115171
"cell_type": "code",
116172
"execution_count": null,
117-
"metadata": {},
173+
"metadata": {
174+
"collapsed": true
175+
},
118176
"outputs": [],
119177
"source": [
120178
"from IPython.display import HTML\n",
121179
"HTML(open(\"input.html\").read())"
122180
]
123181
},
182+
{
183+
"cell_type": "markdown",
184+
"metadata": {},
185+
"source": [
186+
"Now we can use the ``predictor`` object to classify the handwritten digit:"
187+
]
188+
},
124189
{
125190
"cell_type": "code",
126191
"execution_count": null,
127192
"metadata": {
193+
"collapsed": true,
128194
"scrolled": true
129195
},
130196
"outputs": [],
@@ -147,22 +213,28 @@
147213
"collapsed": true
148214
},
149215
"source": [
150-
"# (Optional) Delete the Endpoint"
216+
"# (Optional) Delete the Endpoint\n",
217+
"\n",
218+
"After you have finished with this example, remember to delete the prediction endpoint to release the instance(s) associated with it."
151219
]
152220
},
153221
{
154222
"cell_type": "code",
155223
"execution_count": null,
156-
"metadata": {},
224+
"metadata": {
225+
"collapsed": true
226+
},
157227
"outputs": [],
158228
"source": [
159-
"print(predictor.endpoint)"
229+
"print(\"Endpoint name: \" + predictor.endpoint)"
160230
]
161231
},
162232
{
163233
"cell_type": "code",
164234
"execution_count": null,
165-
"metadata": {},
235+
"metadata": {
236+
"collapsed": true
237+
},
166238
"outputs": [],
167239
"source": [
168240
"import sagemaker\n",
@@ -173,7 +245,9 @@
173245
{
174246
"cell_type": "code",
175247
"execution_count": null,
176-
"metadata": {},
248+
"metadata": {
249+
"collapsed": true
250+
},
177251
"outputs": [],
178252
"source": []
179253
}
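
Note on the "multi-dimensional JSON array" input mentioned in the revised notebook text: the diff does not show the request payload itself, so the following is only a minimal sketch of what such a payload could look like. It assumes the model expects a batch of one single-channel 28x28 image (shape 1x1x28x28), which is a guess based on the MNIST image size described above and not something this commit confirms. The helper name `to_request_payload` is hypothetical and is not part of the SageMaker Python SDK.

```python
import json


def to_request_payload(pixels):
    """JSON-encode a 28x28 grid of pixel values as a 1x1x28x28 nested array.

    Hypothetical helper for illustration only; the batch and channel
    dimensions are assumptions, not taken from mnist.py.
    """
    if len(pixels) != 28 or any(len(row) != 28 for row in pixels):
        raise ValueError("expected a 28x28 grid of pixel values")
    # Wrap the image in a channel dimension, then in a batch dimension.
    batch = [[[[float(value) for value in row] for row in pixels]]]
    return json.dumps(batch)


# A blank image stands in for the pixel data drawn in the notebook.
blank_image = [[0] * 28 for _ in range(28)]
payload = to_request_payload(blank_image)

# Decode again to confirm the nesting of the encoded array.
decoded = json.loads(payload)
shape = (len(decoded), len(decoded[0]), len(decoded[0][0]), len(decoded[0][0][0]))
print(shape)
```

If the array layout matches what the hosted model expects, the decoded nested list is the kind of value one would hand to the ``predictor`` object in the notebook; the exact layout should be checked against ``mnist.py``.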
