Skip to content

Commit 26b5182

Browse files
author
Dan
authored
update pytorch_local_mode_cifar10.ipynb to 1.7.1 (#2078)
1 parent d539d4e commit 26b5182

File tree

2 files changed

+7
-3
lines changed

2 files changed

+7
-3
lines changed

sagemaker-python-sdk/pytorch_cnn_cifar10/pytorch_local_mode_cifar10.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,7 @@
199199
"\n",
200200
"cifar10_estimator = PyTorch(entry_point='source/cifar10.py',\n",
201201
" role=role,\n",
202-
" framework_version='1.4.0',\n",
202+
" framework_version='1.7.1',\n",
203203
" train_instance_count=1,\n",
204204
" train_instance_type=instance_type)\n",
205205
"\n",

sagemaker-python-sdk/pytorch_cnn_cifar10/source/cifar10.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import argparse
22
import logging
3-
import sagemaker_containers
43

54
import os
65

@@ -16,6 +15,11 @@
1615
import torchvision.transforms as transforms
1716
import torch.nn.functional as F
1817

18+
try:
19+
from sagemaker_inference import environment
20+
except:
21+
from sagemaker_training import environment
22+
1923
logger = logging.getLogger(__name__)
2024
logger.setLevel(logging.DEBUG)
2125

@@ -150,7 +154,7 @@ def model_fn(model_dir):
150154
parser.add_argument('--momentum', type=float, default=0.9, metavar='M', help='momentum (default: 0.9)')
151155
parser.add_argument('--dist_backend', type=str, default='gloo', help='distributed backend (default: gloo)')
152156

153-
env = sagemaker_containers.training_env()
157+
env = environment.Environment()
154158
parser.add_argument('--hosts', type=list, default=env.hosts)
155159
parser.add_argument('--current-host', type=str, default=env.current_host)
156160
parser.add_argument('--model-dir', type=str, default=env.model_dir)

0 commit comments

Comments
 (0)