Commit 6365190
Add smdataparallel mnist pytorch and tensorflow2 examples (#97)
1 parent 071cf58 commit 6365190

File tree

8 files changed: +1043 −0 lines changed
Lines changed: 34 additions & 0 deletions
@@ -0,0 +1,34 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

from __future__ import print_function

import os
import torch

# Network definition
from model_def import Net

def model_fn(model_dir):
    print("In model_fn. Model directory is -")
    print(model_dir)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = Net()
    with open(os.path.join(model_dir, 'model.pth'), 'rb') as f:
        print("Loading the mnist model")
        model.load_state_dict(torch.load(f, map_location=device))
    return model
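For context, model_fn is the model-loading hook that the SageMaker PyTorch inference toolkit invokes when a hosting container starts, passing the directory where the training artifacts were unpacked. A minimal local smoke test, assuming the inference script above is importable as mnist_inference (a hypothetical module name; the directory path is also illustrative):

import os
import torch
from model_def import Net
from mnist_inference import model_fn   # hypothetical name for the file above

model_dir = '/tmp/model'               # illustrative artifact directory
os.makedirs(model_dir, exist_ok=True)
# Save an untrained Net the same way the training script does, then reload it.
torch.save(Net().state_dict(), os.path.join(model_dir, 'model.pth'))

model = model_fn(model_dir)            # prints the directory and loads the weights
model.eval()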
Lines changed: 46 additions & 0 deletions
@@ -0,0 +1,46 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

import torch
import torch.nn.functional as F
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout2d(0.25)
        self.dropout2 = nn.Dropout2d(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)
        return output
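One detail worth calling out: the 9216 input width of fc1 is fixed by the MNIST geometry. A 28×28 image shrinks to 26×26 after conv1 (3×3 kernel, stride 1), to 24×24 after conv2, and to 12×12 after the 2×2 max-pool; with 64 channels that flattens to 64 × 12 × 12 = 9216. A quick shape check, assuming only the Net class above:

import torch
from model_def import Net

net = Net()
dummy = torch.zeros(1, 1, 28, 28)   # one fake MNIST image: 28 -> 26 -> 24 -> 12 per side
out = net(dummy)
print(out.shape)                    # torch.Size([1, 10]): log-probabilities over the 10 digits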
Lines changed: 169 additions & 0 deletions
@@ -0,0 +1,169 @@
# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance with the License. A copy of the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for the specific language governing permissions and limitations under the License.

from __future__ import print_function

import os
import argparse
import time
import torch
import torch.nn.functional as F
import torch.optim as optim
import torch.nn as nn
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR

# Network definition
from model_def import Net

# Import SMDataParallel PyTorch Modules
from smdistributed.dataparallel.torch.parallel.distributed import DistributedDataParallel as DDP
import smdistributed.dataparallel.torch.distributed as dist

dist.init_process_group()

def train(args, model, device, train_loader, optimizer, epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % args.log_interval == 0 and args.rank == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data) * args.world_size, len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))
        if args.verbose:
            print('Batch', batch_idx, "from rank", args.rank)


def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)

    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))


def save_model(model, model_dir):
    # Save the wrapped module's weights so they load into a plain Net at inference time
    with open(os.path.join(model_dir, 'model.pth'), 'wb') as f:
        torch.save(model.module.state_dict(), f)


def main():
    # Training settings
    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
    parser.add_argument('--batch-size', type=int, default=64, metavar='N',
                        help='input batch size for training (default: 64)')
    parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
                        help='input batch size for testing (default: 1000)')
    parser.add_argument('--epochs', type=int, default=14, metavar='N',
                        help='number of epochs to train (default: 14)')
    parser.add_argument('--lr', type=float, default=1.0, metavar='LR',
                        help='learning rate (default: 1.0)')
    parser.add_argument('--gamma', type=float, default=0.7, metavar='M',
                        help='Learning rate step gamma (default: 0.7)')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                        help='how many batches to wait before logging training status')
    parser.add_argument('--save-model', action='store_true', default=False,
                        help='For Saving the current Model')
    parser.add_argument('--verbose', action='store_true', default=False,
                        help='For displaying SMDataParallel-specific logs')
    parser.add_argument('--data-path', type=str, default='/tmp/data',
                        help='Path for downloading the MNIST dataset')
    # Model checkpoint location
    parser.add_argument('--model-dir', type=str, default=os.environ['SM_MODEL_DIR'])

    args = parser.parse_args()
    args.world_size = dist.get_world_size()
    args.rank = rank = dist.get_rank()
    args.local_rank = local_rank = dist.get_local_rank()
    args.lr = 1.0  # fixed learning rate; overrides the --lr flag
    # Scale the per-GPU batch size down as the cluster grows past one 8-GPU node
    args.batch_size //= args.world_size // 8
    args.batch_size = max(args.batch_size, 1)
    data_path = args.data_path

    if args.verbose:
        print('Hello from rank', rank, 'of local_rank',
              local_rank, 'in world size of', args.world_size)

    if not torch.cuda.is_available():
        raise Exception("Must run SMDataParallel MNIST example on CUDA-capable devices.")

    torch.manual_seed(args.seed)

    device = torch.device("cuda")

    # Download the dataset once per node (local rank 0); other ranks wait, then read from disk
    if local_rank == 0:
        train_dataset = datasets.MNIST(data_path, train=True, download=True,
                                       transform=transforms.Compose([
                                           transforms.ToTensor(),
                                           transforms.Normalize((0.1307,), (0.3081,))
                                       ]))
    else:
        time.sleep(8)
        train_dataset = datasets.MNIST(data_path, train=True, download=False,
                                       transform=transforms.Compose([
                                           transforms.ToTensor(),
                                           transforms.Normalize((0.1307,), (0.3081,))
                                       ]))

    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset,
        num_replicas=args.world_size,
        rank=rank)
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=0,
        pin_memory=True,
        sampler=train_sampler)
    if rank == 0:
        test_loader = torch.utils.data.DataLoader(
            datasets.MNIST(data_path, train=False, transform=transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.1307,), (0.3081,))
            ])),
            batch_size=args.test_batch_size, shuffle=True)

    # Use SMDataParallel PyTorch DDP for efficient distributed training
    model = DDP(Net().to(device))
    torch.cuda.set_device(local_rank)
    model.cuda(local_rank)
    optimizer = optim.Adadelta(model.parameters(), lr=args.lr)
    scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
    for epoch in range(1, args.epochs + 1):
        train(args, model, device, train_loader, optimizer, epoch)
        if rank == 0:
            test(model, device, test_loader)
        scheduler.step()

    if rank == 0:
        print("Saving the model...")
        save_model(model, args.model_dir)


if __name__ == '__main__':
    main()
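This training script is meant to run as a SageMaker training job on one of the GPU instance types SMDataParallel supports (ml.p3.16xlarge, ml.p3dn.24xlarge, or ml.p4d.24xlarge). A minimal launcher sketch using the SageMaker Python SDK v2, where the entry-point filename, IAM role, and framework versions are illustrative assumptions to adapt to your setup:

import sagemaker
from sagemaker.pytorch import PyTorch

estimator = PyTorch(
    entry_point='train_pytorch_smdataparallel_mnist.py',  # hypothetical name for the script above
    role='SageMakerRole',                                 # hypothetical IAM role
    instance_type='ml.p3.16xlarge',                       # 8 GPUs per node
    instance_count=1,
    framework_version='1.6.0',
    py_version='py3',
    # Enable the SMDataParallel data-parallel backend for this job
    distribution={'smdistributed': {'dataparallel': {'enabled': True}}},
)
estimator.fit()

With instance_count=1 on an 8-GPU node, dist.get_world_size() returns 8 inside the job, so the batch-size scaling in main() leaves the per-GPU batch size unchanged; adding nodes halves it accordingly.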
