Skip to content

Commit ea3da6e

Browse files
authored
upgrade MNIST experiment notebook to SDK v2 (#1576)
1 parent db08c55 commit ea3da6e

File tree

2 files changed

+229
-123
lines changed

2 files changed

+229
-123
lines changed

sagemaker-experiments/mnist-handwritten-digits-classification-experiment/mnist-handwritten-digits-classification-experiment.ipynb

Lines changed: 29 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@
4444
"metadata": {},
4545
"outputs": [],
4646
"source": [
47-
"!{sys.executable} -m pip install sagemaker-experiments"
47+
"!{sys.executable} -m pip install sagemaker-experiments==0.1.24"
4848
]
4949
},
5050
{
@@ -60,8 +60,12 @@
6060
"metadata": {},
6161
"outputs": [],
6262
"source": [
63-
"!{sys.executable} -m pip install torch\n",
64-
"!{sys.executable} -m pip install torchvision"
63+
"# pytorch version needs to be the same in both the notebook instance and the training job container \n",
64+
"# https://github.com/pytorch/pytorch/issues/25214\n",
65+
"!{sys.executable} -m pip install torch==1.1.0\n",
66+
"!{sys.executable} -m pip install torchvision==0.3.0\n",
67+
"!{sys.executable} -m pip install pillow==6.2.2 ",
68+
"!{sys.executable} -m pip install --upgrade sagemaker",
6569
]
6670
},
6771
{
@@ -73,7 +77,7 @@
7377
},
7478
{
7579
"cell_type": "code",
76-
"execution_count": null,
80+
"execution_count": 1,
7781
"metadata": {},
7882
"outputs": [],
7983
"source": [
@@ -82,7 +86,7 @@
8286
"import boto3\n",
8387
"import numpy as np\n",
8488
"import pandas as pd\n",
85-
"%config InlineBackend.figure_format = 'retina'\n",
89+
"from IPython.display import set_matplotlib_formats\n",
8690
"from matplotlib import pyplot as plt\n",
8791
"from torchvision import datasets, transforms\n",
8892
"\n",
@@ -94,7 +98,9 @@
9498
"from smexperiments.experiment import Experiment\n",
9599
"from smexperiments.trial import Trial\n",
96100
"from smexperiments.trial_component import TrialComponent\n",
97-
"from smexperiments.tracker import Tracker"
101+
"from smexperiments.tracker import Tracker\n",
102+
"\n",
103+
"set_matplotlib_formats('retina')"
98104
]
99105
},
100106
{
@@ -307,12 +313,13 @@
307313
" # all input configurations, parameters, and metrics specified in estimator \n",
308314
" # definition are automatically tracked\n",
309315
" estimator = PyTorch(\n",
316+
" py_version='py3',\n",
310317
" entry_point='./mnist.py',\n",
311318
" role=role,\n",
312319
" sagemaker_session=sagemaker.Session(sagemaker_client=sm),\n",
313320
" framework_version='1.1.0',\n",
314-
" train_instance_count=1,\n",
315-
" train_instance_type='ml.c4.xlarge',\n",
321+
" instance_count=1,\n",
322+
" instance_type='ml.c4.xlarge',\n",
316323
" hyperparameters={\n",
317324
" 'epochs': 2,\n",
318325
" 'backend': 'gloo',\n",
@@ -470,6 +477,7 @@
470477
" model_data, \n",
471478
" role, \n",
472479
" './mnist.py', \n",
480+
" py_version='py3',\n",
473481
" env=env, \n",
474482
" sagemaker_session=sagemaker.Session(sagemaker_client=sm),\n",
475483
" framework_version='1.1.0',\n",
@@ -497,39 +505,29 @@
497505
"metadata": {},
498506
"outputs": [],
499507
"source": [
500-
"predictor.delete_endpoint()\n",
501-
"\n",
502-
"def cleanup(experiment):\n",
503-
" for trial_summary in experiment.list_trials():\n",
504-
" trial = Trial.load(sagemaker_boto_client=sm, trial_name=trial_summary.trial_name)\n",
505-
" for trial_component_summary in trial.list_trial_components():\n",
506-
" tc = TrialComponent.load(\n",
507-
" sagemaker_boto_client=sm,\n",
508-
" trial_component_name=trial_component_summary.trial_component_name)\n",
509-
" trial.remove_trial_component(tc)\n",
510-
" try:\n",
511-
" # comment out to keep trial components\n",
512-
" tc.delete()\n",
513-
" except:\n",
514-
" # tc is associated with another trial\n",
515-
" continue\n",
516-
" # to prevent throttling\n",
517-
" time.sleep(.5)\n",
518-
" trial.delete()\n",
519-
" experiment.delete()\n",
520-
"\n",
521-
"cleanup(mnist_experiment)"
508+
"predictor.delete_endpoint()"
522509
]
523510
},
524511
{
525512
"cell_type": "code",
526513
"execution_count": null,
527514
"metadata": {},
528515
"outputs": [],
529-
"source": []
516+
"source": [
517+
"mnist_experiment.delete_all(action='--force')"
518+
]
519+
},
520+
{
521+
"cell_type": "markdown",
522+
"metadata": {},
523+
"source": [
524+
"## Contact\n",
525+
"Submit any questions or issues to https://github.com/aws/sagemaker-experiments/issues or mention @aws/sagemakerexperimentsadmin "
526+
]
530527
}
531528
],
532529
"metadata": {
530+
"instance_type": "ml.t3.medium",
533531
"kernelspec": {
534532
"display_name": "Python 3 (Data Science)",
535533
"language": "python",

0 commit comments

Comments
 (0)