|
4 | 4 | "cell_type": "markdown",
|
5 | 5 | "metadata": {},
|
6 | 6 | "source": [
|
7 |
| - "# MNIST training with PyTorch\n", |
| 7 | + "# Train an MNIST model with PyTorch\n", |
8 | 8 | "\n",
|
9 |
| - "MNIST is a widely used dataset for handwritten digit classification. It consists of 70,000 labeled 28x28 pixel grayscale images of hand-written digits. The dataset is split into 60,000 training images and 10,000 test images. There are 10 classes (one for each of the 10 digits). This tutorial will show how to train and test an MNIST model on SageMaker using PyTorch. \n", |
10 |
| - "\n" |
| 9 | + "MNIST is a widely used dataset for handwritten digit classification. It consists of 70,000 labeled 28x28 pixel grayscale images of hand-written digits. The dataset is split into 60,000 training images and 10,000 test images. There are 10 classes (one for each of the 10 digits). This tutorial shows how to train and test an MNIST model on SageMaker using PyTorch. \n", |
| 10 | + "\n", |
| 11 | + "## Runtime\n", |
| 12 | + "\n", |
| 13 | + "This notebook takes approximately 5 minutes to run.\n", |
| 14 | + "\n", |
| 15 | + "## Contents\n", |
| 16 | + "\n", |
| 17 | + "1. [PyTorch Estimator](#PyTorch-Estimator)\n", |
| 18 | + "1. [Implement the entry point for training](#Implement-the-entry-point-for-training)\n", |
| 19 | + "1. [Set hyperparameters](#Set-hyperparameters)\n", |
| 20 | + "1. [Set up channels for the training and testing data](#Set-up-channels-for-the-training-and-testing-data)\n", |
| 21 | + "1. [Run the training script on SageMaker](#Run-the-training-script-on-SageMaker)\n", |
| 22 | + "1. [Inspect and store model data](#Inspect-and-store-model-data)\n", |
| 23 | + "1. [Test and debug the entry point before executing the training container](#Test-and-debug-the-entry-point-before-executing-the-training-container)\n", |
| 24 | + "1. [Conclusion](#Conclusion)" |
11 | 25 | ]
|
12 | 26 | },
|
13 | 27 | {
|
|
28 | 42 | "\n",
|
29 | 43 | "role = get_execution_role()\n",
|
30 | 44 | "\n",
|
31 |
| - "output_path = \"s3://\" + sess.default_bucket() + \"/mnist\"" |
| 45 | + "output_path = \"s3://\" + sess.default_bucket() + \"/DEMO-mnist\"" |
32 | 46 | ]
|
33 | 47 | },
|
34 | 48 | {
|
|
44 | 58 | "You need to configure\n",
|
45 | 59 | "it with the following parameters to set up the environment:\n",
|
46 | 60 | "\n",
|
47 |
| - "- entry_point: A user defined python file to be used by the training container as the \n", |
| 61 | + "- `entry_point`: A user-defined Python file used by the training container as the \n", |
48 | 62 | "instructions for training. We further discuss this file in the next subsection.\n",
|
49 | 63 | "\n",
|
50 |
| - "- role: An IAM role to make AWS service requests\n", |
| 64 | + "- `role`: An IAM role to make AWS service requests\n", |
51 | 65 | "\n",
|
52 |
| - "- instance_type: The type of SageMaker instance to run your training script. \n", |
| 66 | + "- `instance_type`: The type of SageMaker instance to run your training script. \n", |
53 | 67 | "Set it to `local` if you want to run the training job on \n",
|
54 | 68 | "the SageMaker instance you are using to run this notebook\n",
|
55 | 69 | "\n",
|
56 |
| - "- instance count: The number of instances you need to run your training job. \n", |
| 70 | + "- `instance_count`: The number of instances to run your training job on. \n", |
57 | 71 | "Multiple instances are needed for distributed training.\n",
|
58 | 72 | "\n",
|
59 |
| - "- output_path: \n", |
| 73 | + "- `output_path`: \n", |
60 | 74 | "S3 bucket URI to save training output (model artifacts and output files)\n",
|
61 | 75 | "\n",
|
62 |
| - "- framework_version: The version of PyTorch you need to use.\n", |
| 76 | + "- `framework_version`: The version of PyTorch to use\n", |
63 | 77 | "\n",
|
64 |
| - "- py_version: The python version you need to use\n", |
| 78 | + "- `py_version`: The Python version to use\n", |
65 | 79 | "\n",
|
66 |
| - "For more information, see [the API reference](https://sagemaker.readthedocs.io/en/stable/api/training/estimators.html#sagemaker.estimator.EstimatorBase)\n", |
| 80 | + "For more information, see the [EstimatorBase API reference](https://sagemaker.readthedocs.io/en/stable/api/training/estimators.html#sagemaker.estimator.EstimatorBase)\n", |
67 | 81 | "\n"
|
68 | 82 | ]
|
69 | 83 | },
|
|
73 | 87 | "source": [
|
74 | 88 | "## Implement the entry point for training\n",
|
75 | 89 | "\n",
|
76 |
| - "The entry point for training is a python script that provides all \n", |
| 90 | + "The entry point for training is a Python script that provides all \n", |
77 | 91 | "the code for training a PyTorch model. It is used by the SageMaker \n",
|
78 | 92 | "PyTorch Estimator (`PyTorch` class above) as the entry point for running the training job.\n",
|
79 | 93 | "\n",
|
80 | 94 | "Under the hood, SageMaker PyTorch Estimator creates a docker image\n",
|
81 | 95 | "with runtime environments \n",
|
82 |
| - "specified by the parameters you used to initiated the\n", |
83 |
| - "estimator class and it injects the training script into the \n", |
84 |
| - "docker image to be used as the entry point to run the container.\n", |
| 96 | + "specified by the parameters you provide to initiate the\n", |
| 97 | + "estimator class, and it injects the training script into the \n", |
| 98 | + "docker image as the entry point to run the container.\n", |
85 | 99 | "\n",
|
86 | 100 | "In the rest of the notebook, we use *training image* to refer to the \n",
|
87 | 101 | "docker image specified by the PyTorch Estimator and *training container*\n",
|
88 | 102 | "to refer to the container that runs the training image. \n",
|
89 | 103 | "\n",
|
90 | 104 | "This means your training script is very similar to a training script\n",
|
91 | 105 | "you might run outside Amazon SageMaker, but it can access the useful environment \n",
|
92 |
| - "variables provided by the training image. Checkout [the short list of environment variables provided by the SageMaker service](https://sagemaker.readthedocs.io/en/stable/frameworks/mxnet/using_mxnet.html?highlight=entry%20point) to see some common environment \n", |
93 |
| - "variables you might used. Checkout [the complete list of environment variables](https://github.com/aws/sagemaker-training-toolkit/blob/master/ENVIRONMENT_VARIABLES.md) for a complete \n", |
| 106 | + "variables provided by the training image. See [the complete list of environment variables](https://github.com/aws/sagemaker-training-toolkit/blob/master/ENVIRONMENT_VARIABLES.md) for a complete \n", |
94 | 107 | "description of all environment variables your training script\n",
|
95 |
| - "can access to. \n", |
| 108 | + "can access. \n", |
96 | 109 | "\n",
|
97 | 110 | "In this example, we use the training script `code/train.py`\n",
|
98 | 111 | "as the entry point for our PyTorch Estimator.\n"
|
|
111 | 124 | "cell_type": "markdown",
|
112 | 125 | "metadata": {},
|
113 | 126 | "source": [
|
114 |
| - "### Set hyperparameters\n", |
| 127 | + "## Set hyperparameters\n", |
115 | 128 | "\n",
|
116 |
| - "In addition, PyTorch estimator allows you to parse command line arguments\n", |
| 129 | + "In addition, the PyTorch estimator allows you to pass command line arguments\n", |
117 | 130 | "to your training script via `hyperparameters`.\n",
|
118 | 131 | "\n",
|
119 |
| - "<span style=\"color:red\"> Note: local mode is not supported in SageMaker Studio </span>" |
| 132 | + "Note: local mode is not supported in SageMaker Studio. " |
120 | 133 | ]
|
121 | 134 | },
|
122 | 135 | {
|
|
125 | 138 | "metadata": {},
|
126 | 139 | "outputs": [],
|
127 | 140 | "source": [
|
128 |
| - "# set local_mode to be True if you want to run the training script\n", |
129 |
| - "# on the machine that runs this notebook\n", |
| 141 | + "# Set local_mode to True to run the training script on the machine that runs this notebook\n", |
130 | 142 | "\n",
|
131 | 143 | "local_mode = False\n",
|
132 | 144 | "\n",
|
|
153 | 165 | "cell_type": "markdown",
|
154 | 166 | "metadata": {},
|
155 | 167 | "source": [
|
156 |
| - "The training container executes your training script like\n", |
| 168 | + "The training container executes your training script like:\n", |
157 | 169 | "\n",
|
158 | 170 | "```\n",
|
159 |
| - "python train.py --batch-size 100 --epochs 1 --learning-rate 1e-3 \\\n", |
160 |
| - " --log-interval 100\n", |
| 171 | + "python train.py --batch-size 100 --epochs 1 --learning-rate 1e-3 --log-interval 100\n", |
161 | 172 | "```"
|
162 | 173 | ]
|
163 | 174 | },
|
164 | 175 | {
|
165 | 176 | "cell_type": "markdown",
|
166 | 177 | "metadata": {},
|
167 | 178 | "source": [
|
168 |
| - "## Set up channels for training and testing data\n", |
| 179 | + "## Set up channels for the training and testing data\n", |
169 | 180 | "\n",
|
170 |
| - "You need to tell `PyTorch` estimator where to find your training and \n", |
171 |
| - "testing data. It can be a link to an S3 bucket or it can be a path\n", |
| 181 | + "Tell the `PyTorch` estimator where to find the training and \n", |
| 182 | + "testing data. It can be a path to an S3 bucket, or a path\n", |
172 | 183 | "in your local file system if you use local mode. In this example,\n",
|
173 | 184 | "we download the MNIST data from a public S3 bucket and upload it \n",
|
174 | 185 | "to your default bucket. "
|
|
184 | 195 | "import boto3\n",
|
185 | 196 | "from botocore.exceptions import ClientError\n",
|
186 | 197 | "\n",
|
187 |
| - "\n", |
188 | 198 | "# Download training and testing data from a public S3 bucket\n",
|
189 | 199 | "\n",
|
190 | 200 | "\n",
|
191 |
| - "def download_from_s3(data_dir=\"/tmp/data\", train=True):\n", |
| 201 | + "def download_from_s3(data_dir=\"./data\", train=True):\n", |
192 | 202 | " \"\"\"Download MNIST dataset and convert it to numpy array\n",
|
193 | 203 | "\n",
|
194 | 204 | " Args:\n",
|
|
220 | 230 | " return\n",
|
221 | 231 | "\n",
|
222 | 232 | "\n",
|
223 |
| - "download_from_s3(\"/tmp/data\", True)\n", |
224 |
| - "download_from_s3(\"/tmp/data\", False)" |
| 233 | + "download_from_s3(\"./data\", True)\n", |
| 234 | + "download_from_s3(\"./data\", False)" |
225 | 235 | ]
|
226 | 236 | },
|
227 | 237 | {
|
|
230 | 240 | "metadata": {},
|
231 | 241 | "outputs": [],
|
232 | 242 | "source": [
|
233 |
| - "# upload to the default bucket\n", |
| 243 | + "# Upload to the default bucket\n", |
234 | 244 | "\n",
|
235 |
| - "prefix = \"mnist\"\n", |
| 245 | + "prefix = \"DEMO-mnist\"\n", |
236 | 246 | "bucket = sess.default_bucket()\n",
|
237 |
| - "loc = sess.upload_data(path=\"/tmp/data\", bucket=bucket, key_prefix=prefix)\n", |
| 247 | + "loc = sess.upload_data(path=\"./data\", bucket=bucket, key_prefix=prefix)\n", |
238 | 248 | "\n",
|
239 | 249 | "channels = {\"training\": loc, \"testing\": loc}"
|
240 | 250 | ]
|
|
243 | 253 | "cell_type": "markdown",
|
244 | 254 | "metadata": {},
|
245 | 255 | "source": [
|
246 |
| - "The keys of the dictionary `channels` are parsed to the training image\n", |
| 256 | + "The keys of the `channels` dictionary are passed to the training image,\n", |
247 | 257 | "and it creates the environment variable `SM_CHANNEL_<key name>`. \n",
|
248 | 258 | "\n",
|
249 |
| - "In this example, `SM_CHANNEL_TRAINING` and `SM_CHANNEL_TESTING` are created in the training image (checkout \n", |
250 |
| - "how `code/train.py` access these variables). For more information,\n", |
251 |
| - "see: [SM_CHANNEL_{channel_name}](https://github.com/aws/sagemaker-training-toolkit/blob/master/ENVIRONMENT_VARIABLES.md#sm_channel_channel_name)\n", |
| 259 | + "In this example, `SM_CHANNEL_TRAINING` and `SM_CHANNEL_TESTING` are created in the training image (see \n", |
| 260 | + "how `code/train.py` accesses these variables). For more information,\n", |
| 261 | + "see: [SM_CHANNEL_{channel_name}](https://github.com/aws/sagemaker-training-toolkit/blob/master/ENVIRONMENT_VARIABLES.md#sm_channel_channel_name).\n", |
252 | 262 | "\n",
|
253 | 263 | "If you want, you can create a channel for validation:\n",
|
254 | 264 | "```\n",
|
255 | 265 | "channels = {\n",
|
256 | 266 | " 'training': train_data_loc,\n",
|
257 | 267 | " 'validation': val_data_loc,\n",
|
258 | 268 | " 'test': test_data_loc\n",
|
259 |
| - " }\n", |
| 269 | + "}\n", |
260 | 270 | "```\n",
|
261 | 271 | "You can then access this channel within your training script via\n",
|
262 |
| - "`SM_CHANNEL_VALIDATION`\n" |
| 272 | + "`SM_CHANNEL_VALIDATION`.\n" |
263 | 273 | ]
|
264 | 274 | },
|
265 | 275 | {
|
|
268 | 278 | "source": [
|
269 | 279 | "## Run the training script on SageMaker\n",
|
270 | 280 | "Now, the training container has everything to execute your training\n",
|
271 |
| - "script. You can start the container by calling `fit` method." |
| 281 | + "script. Start the container by calling the `fit()` method." |
272 | 282 | ]
|
273 | 283 | },
|
274 | 284 | {
|
|
288 | 298 | "source": [
|
289 | 299 | "## Inspect and store model data\n",
|
290 | 300 | "\n",
|
291 |
| - "Now, the training is finished, the model artifact has been saved in \n", |
292 |
| - "the `output_path`. We " |
| 301 | + "Now, the training is finished, and the model artifact has been saved in \n", |
| 302 | + "the `output_path`." |
293 | 303 | ]
|
294 | 304 | },
|
295 | 305 | {
|
|
306 | 316 | "cell_type": "markdown",
|
307 | 317 | "metadata": {},
|
308 | 318 | "source": [
|
309 |
| - "We store the variable `model_data` in the current notebook kernel. \n", |
310 |
| - "In the [next notebook](get_started_with_mnist_deploy.ipynb), you will learn how to retrieve the model artifact and deploy to a SageMaker\n", |
311 |
| - "endpoint." |
| 319 | + "We store the variable `pt_mnist_model_data` in the current notebook kernel." |
312 | 320 | ]
|
313 | 321 | },
|
314 | 322 | {
|
|
326 | 334 | "source": [
|
327 | 335 | "## Test and debug the entry point before executing the training container\n",
|
328 | 336 | "\n",
|
329 |
| - "The entry point `code/train.py` provided here has been tested and it can be executed in the training container. \n", |
330 |
| - "When you do develop your own training script, it is a good practice to simulate the container environment \n", |
| 337 | + "The entry point `code/train.py` can be executed in the training container. \n", |
| 338 | + "When you develop your own training script, it is a good practice to simulate the container environment \n", |
331 | 339 | "in the local shell and test it before sending it to SageMaker, because debugging in a containerized environment\n",
|
332 | 340 | "is rather cumbersome. The following script shows how you can test your training script:"
|
333 | 341 | ]
|
|
340 | 348 | "source": [
|
341 | 349 | "!pygmentize code/test_train.py"
|
342 | 350 | ]
|
| 351 | + }, |
| 352 | + { |
| 353 | + "cell_type": "markdown", |
| 354 | + "metadata": {}, |
| 355 | + "source": [ |
| 356 | + "## Conclusion\n", |
| 357 | + "\n", |
| 358 | + "In this notebook, we trained a PyTorch model on the MNIST dataset by fitting a SageMaker estimator. For next steps on how to deploy the trained model and perform inference, see [Deploy a Trained PyTorch Model](https://sagemaker-examples.readthedocs.io/en/latest/frameworks/pytorch/get_started_mnist_deploy.html)." |
| 359 | + ] |
343 | 360 | }
|
344 | 361 | ],
|
345 | 362 | "metadata": {
|
346 | 363 | "kernelspec": {
|
347 |
| - "display_name": "Environment (conda_pytorch_p36)", |
| 364 | + "display_name": "conda_pytorch_p36", |
348 | 365 | "language": "python",
|
349 | 366 | "name": "conda_pytorch_p36"
|
350 | 367 | },
|
|
358 | 375 | "name": "python",
|
359 | 376 | "nbconvert_exporter": "python",
|
360 | 377 | "pygments_lexer": "ipython3",
|
361 |
| - "version": "3.6.10" |
| 378 | + "version": "3.6.13" |
362 | 379 | },
|
363 | 380 | "notice": "Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the \"License\"). You may not use this file except in compliance with the License. A copy of the License is located at http://aws.amazon.com/apache2.0/ or in the \"license\" file accompanying this file. This file is distributed on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License."
|
364 | 381 | },
|
|