
Commit 1ef46cc

Merge pull request aws#365 from ChoiByungWook/pytorch_extend
Extending SageMaker PyTorch containers
2 parents 7cc1278 + 45006a1 commit 1ef46cc

File tree: 9 files changed, +1100 −0 lines changed
Lines changed: 31 additions & 0 deletions
@@ -0,0 +1,31 @@
# Copyright 2017-2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.

# For more information on creating a Dockerfile:
# https://docs.docker.com/compose/gettingstarted/#step-2-create-a-dockerfile
# https://github.com/awslabs/amazon-sagemaker-examples/master/advanced_functionality/pytorch_extending_our_containers/pytorch_extending_our_containers.ipynb

# SageMaker PyTorch image
FROM 520713654638.dkr.ecr.us-west-2.amazonaws.com/sagemaker-pytorch:0.4.0-cpu-py3

ENV PATH="/opt/ml/code:${PATH}"

# /opt/ml and all of its subdirectories are used by SageMaker; we use the
# /code subdirectory to store our user code.
COPY /cifar10 /opt/ml/code

# This environment variable is used by the SageMaker PyTorch container to
# determine our user code directory.
ENV SAGEMAKER_SUBMIT_DIRECTORY /opt/ml/code

# This environment variable is used by the SageMaker PyTorch container to
# determine our program entry point for training and serving.
# For more information: https://github.com/aws/sagemaker-pytorch-container
ENV SAGEMAKER_PROGRAM cifar10.py
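
With the user code copied in and these environment variables set, the resulting image can be used directly as a SageMaker training image. Below is a minimal sketch using the v1-era SageMaker Python SDK; the ECR image URI, IAM role, and S3 data path are hypothetical placeholders, and the image is assumed to have been pushed with the build-and-push script that follows.

from sagemaker.estimator import Estimator

# Hypothetical placeholders; substitute your own account, role, and bucket.
ecr_image = '123456789012.dkr.ecr.us-west-2.amazonaws.com/pytorch-extending-our-containers-cifar10-example:latest'
role = 'arn:aws:iam::123456789012:role/SageMakerRole'

estimator = Estimator(image_name=ecr_image,
                      role=role,
                      train_instance_count=1,
                      train_instance_type='ml.m4.xlarge',
                      hyperparameters={'epochs': 1})

# The 'training' channel becomes SM_CHANNEL_TRAINING inside the container,
# which cifar10.py reads as its --data-dir default.
estimator.fit({'training': 's3://my-bucket/data/cifar10'})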
Lines changed: 53 additions & 0 deletions
@@ -0,0 +1,53 @@
#!/usr/bin/env bash

# This script shows how to build the Docker image and push it to ECR to be ready for use
# by SageMaker.

# The argument to this script is the image name. This will be used as the image on the local
# machine and combined with the account and region to form the repository name for ECR.
image=$1

if [ "$image" == "" ]
then
    echo "Usage: $0 <image-name>"
    exit 1
fi

# Get the account number associated with the current IAM credentials
account=$(aws sts get-caller-identity --query Account --output text)

if [ $? -ne 0 ]
then
    exit 255
fi

# Get the region defined in the current configuration (default to us-west-2 if none defined)
region=$(aws configure get region)
region=${region:-us-west-2}

fullname="${account}.dkr.ecr.${region}.amazonaws.com/${image}:latest"

# If the repository doesn't exist in ECR, create it.
aws ecr describe-repositories --repository-names "${image}" > /dev/null 2>&1

if [ $? -ne 0 ]
then
    aws ecr create-repository --repository-name "${image}" > /dev/null
fi

# Get the login command from ECR and execute it directly
$(aws ecr get-login --region ${region} --no-include-email)

# Get the login command from ECR in order to pull down the SageMaker PyTorch image
$(aws ecr get-login --registry-ids 520713654638 --region ${region} --no-include-email)

# Build the docker image locally with the image name and then push it to ECR
# with the full name.
docker build -t ${image} .
docker tag ${image} ${fullname}

docker push ${fullname}
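
Once the script has pushed the image, the same fully qualified name can be reconstructed programmatically (for example, from a notebook) to hand to the SDK. A minimal Python sketch, assuming boto3 credentials are configured; the image name is a hypothetical example:

import boto3

# Derive the account and region the same way the shell script does.
account = boto3.client('sts').get_caller_identity()['Account']
region = boto3.session.Session().region_name or 'us-west-2'
image = 'pytorch-extending-our-containers-cifar10-example'  # hypothetical image name

ecr_image = '{}.dkr.ecr.{}.amazonaws.com/{}:latest'.format(account, region, image)
print(ecr_image)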
Lines changed: 173 additions & 0 deletions
@@ -0,0 +1,173 @@
# Copyright 2017-2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
import ast
import argparse
import logging

import os

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.parallel
import torch.optim
import torch.utils.data
import torch.utils.data.distributed
import torchvision
import torchvision.models
import torchvision.transforms as transforms
import torch.nn.functional as F

logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)

classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')


# https://github.com/pytorch/tutorials/blob/master/beginner_source/blitz/cifar10_tutorial.py#L118
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


def _train(args):
    is_distributed = len(args.hosts) > 1 and args.dist_backend is not None
    logger.debug("Distributed training - {}".format(is_distributed))

    if is_distributed:
        # Initialize the distributed environment.
        world_size = len(args.hosts)
        os.environ['WORLD_SIZE'] = str(world_size)
        host_rank = args.hosts.index(args.current_host)
        dist.init_process_group(backend=args.dist_backend, rank=host_rank, world_size=world_size)
        logger.info(
            'Initialized the distributed environment: \'{}\' backend on {} nodes. '.format(
                args.dist_backend,
                dist.get_world_size()) + 'Current host rank is {}. Using cuda: {}. Number of gpus: {}'.format(
                dist.get_rank(), torch.cuda.is_available(), args.num_gpus))

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    logger.info("Device Type: {}".format(device))

    logger.info("Loading Cifar10 dataset")
    transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

    trainset = torchvision.datasets.CIFAR10(root=args.data_dir, train=True,
                                            download=False, transform=transform)
    train_loader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size,
                                               shuffle=True, num_workers=args.workers)

    testset = torchvision.datasets.CIFAR10(root=args.data_dir, train=False,
                                           download=False, transform=transform)
    test_loader = torch.utils.data.DataLoader(testset, batch_size=args.batch_size,
                                              shuffle=False, num_workers=args.workers)

    logger.info("Model loaded")
    model = Net()

    if torch.cuda.device_count() > 1:
        logger.info("Gpu count: {}".format(torch.cuda.device_count()))
        model = nn.DataParallel(model)

    model = model.to(device)

    criterion = nn.CrossEntropyLoss().to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)

    for epoch in range(0, args.epochs):
        running_loss = 0.0
        for i, data in enumerate(train_loader):
            # get the inputs
            inputs, labels = data
            inputs, labels = inputs.to(device), labels.to(device)

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # print statistics
            running_loss += loss.item()
            if i % 2000 == 1999:  # print every 2000 mini-batches
                print('[%d, %5d] loss: %.3f' %
                      (epoch + 1, i + 1, running_loss / 2000))
                running_loss = 0.0
    print('Finished Training')
    return _save_model(model, args.model_dir)


def _save_model(model, model_dir):
    logger.info("Saving the model.")
    path = os.path.join(model_dir, 'model.pth')
    # recommended way from http://pytorch.org/docs/master/notes/serialization.html
    torch.save(model.cpu().state_dict(), path)


def model_fn(model_dir):
    logger.info('model_fn')
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = Net()
    if torch.cuda.device_count() > 1:
        logger.info("Gpu count: {}".format(torch.cuda.device_count()))
        model = nn.DataParallel(model)

    with open(os.path.join(model_dir, 'model.pth'), 'rb') as f:
        model.load_state_dict(torch.load(f))
    return model.to(device)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()

    parser.add_argument('--workers', type=int, default=2, metavar='W',
                        help='number of data loading workers (default: 2)')
    parser.add_argument('--epochs', type=int, default=2, metavar='E',
                        help='number of total epochs to run (default: 2)')
    parser.add_argument('--batch-size', type=int, default=4, metavar='BS',
                        help='batch size (default: 4)')
    parser.add_argument('--lr', type=float, default=0.001, metavar='LR',
                        help='initial learning rate (default: 0.001)')
    parser.add_argument('--momentum', type=float, default=0.9, metavar='M', help='momentum (default: 0.9)')
    parser.add_argument('--dist-backend', type=str, default='gloo', help='distributed backend (default: gloo)')

    # The parameters below retrieve their default values from SageMaker environment variables, which are
    # instantiated by the SageMaker containers framework.
    # https://github.com/aws/sagemaker-containers#how-a-script-is-executed-inside-the-container
    # Note: argparse does not apply `type` to defaults, so the environment values are converted explicitly.
    parser.add_argument('--hosts', type=str, default=ast.literal_eval(os.environ['SM_HOSTS']))
    parser.add_argument('--current-host', type=str, default=os.environ['SM_CURRENT_HOST'])
    parser.add_argument('--model-dir', type=str, default=os.environ['SM_MODEL_DIR'])
    parser.add_argument('--data-dir', type=str, default=os.environ['SM_CHANNEL_TRAINING'])
    parser.add_argument('--num-gpus', type=int, default=int(os.environ['SM_NUM_GPUS']))

    _train(parser.parse_args())
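
For hosting, the SageMaker PyTorch container imports this script and calls model_fn to load the model; input, prediction, and output handling then fall back to the container defaults unless overridden. A minimal local sketch of that flow, assuming a directory containing the saved model.pth (the path is hypothetical):

import torch

model = model_fn('/tmp/model')  # hypothetical directory containing model.pth
model.eval()

# A single fake CIFAR-10-shaped input: batch of 1, 3 channels, 32x32 pixels.
sample = torch.randn(1, 3, 32, 32)
device = next(model.parameters()).device
with torch.no_grad():
    logits = model(sample.to(device))
print('Predicted class:', classes[logits.max(1)[1].item()])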
