Skip to content

Commit 67ec576

Browse files
authored
Support MXNet Horovod (aws#373)
* support mxnet horovod
* mxnet hvd example
1 parent 7e534d1 commit 67ec576

File tree

2 files changed

+215
-0
lines changed

2 files changed

+215
-0
lines changed
Lines changed: 201 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,201 @@
1+
import argparse
2+
import logging
3+
import os
4+
import zipfile
5+
import time
6+
7+
import mxnet as mx
8+
import horovod.mxnet as hvd
9+
from mxnet import autograd, gluon, nd
10+
from mxnet.test_utils import download
11+
12+
from tornasole import SaveConfig, modes
13+
from tornasole.mxnet import TornasoleHook
14+
15+
# Training settings
16+
parser = argparse.ArgumentParser(description="MXNet MNIST Example")
17+
18+
parser.add_argument("--batch-size", type=int, default=64, help="training batch size (default: 64)")
19+
parser.add_argument(
20+
"--dtype", type=str, default="float32", help="training data type (default: float32)"
21+
)
22+
parser.add_argument("--epochs", type=int, default=5, help="number of training epochs (default: 5)")
23+
parser.add_argument("--lr", type=float, default=0.01, help="learning rate (default: 0.01)")
24+
parser.add_argument("--momentum", type=float, default=0.9, help="SGD momentum (default: 0.9)")
25+
parser.add_argument(
26+
"--no-cuda", action="store_true", default=False, help="disable training on GPU (default: False)"
27+
)
28+
parser.add_argument(
29+
"--output-uri",
30+
type=str,
31+
default="/opt/ml/output/tensors/tornasole",
32+
help="S3 URI of the bucket where tensor data will be stored.",
33+
)
34+
args = parser.parse_args()
35+
36+
if not args.no_cuda:
37+
# Disable CUDA if there are no GPUs.
38+
if not mx.test_utils.list_gpus():
39+
args.no_cuda = True
40+
41+
logging.basicConfig(level=logging.INFO)
42+
logging.info(args)
43+
44+
45+
# Function to get mnist iterator given a rank
46+
def get_mnist_iterator(rank):
    """Download MNIST into a per-rank directory and build the data iterators.

    Args:
        rank: Index of this worker. It names the download directory
            (``data-<rank>``) and selects which training partition this
            worker reads (``part_index``).

    Returns:
        A ``(train_iter, val_iter)`` tuple of ``mx.io.MNISTIter`` objects.
        The training iterator is shuffled and split into ``hvd.size()``
        parts; the validation iterator is not partitioned.
    """
    data_dir = "data-%d" % rank
    # exist_ok avoids the check-then-create race of isdir() followed by makedirs().
    os.makedirs(data_dir, exist_ok=True)
    zip_file_path = download("http://data.mxnet.io/mxnet/data/mnist.zip", dirname=data_dir)
    with zipfile.ZipFile(zip_file_path) as zf:
        zf.extractall(data_dir)

    input_shape = (1, 28, 28)
    batch_size = args.batch_size

    train_iter = mx.io.MNISTIter(
        image="%s/train-images-idx3-ubyte" % data_dir,
        label="%s/train-labels-idx1-ubyte" % data_dir,
        input_shape=input_shape,
        batch_size=batch_size,
        shuffle=True,
        flat=False,
        num_parts=hvd.size(),
        # Honour the caller-supplied rank so the data directory and the
        # partition index always refer to the same worker.
        part_index=rank,
    )

    val_iter = mx.io.MNISTIter(
        image="%s/t10k-images-idx3-ubyte" % data_dir,
        label="%s/t10k-labels-idx1-ubyte" % data_dir,
        input_shape=input_shape,
        batch_size=batch_size,
        flat=False,
    )

    return train_iter, val_iter
77+
78+
79+
# Function to define neural network
80+
def conv_nets():
    """Build the convolutional network used for MNIST classification.

    Returns:
        A ``gluon.nn.HybridSequential`` made of two Conv2D + MaxPool2D stages,
        a Flatten layer, a 512-unit ReLU Dense layer and a 10-unit Dense
        output layer.
    """
    network = gluon.nn.HybridSequential()
    # Layers are constructed inside the name scope, in the same order as
    # before, so they are created under the same scope and sequence.
    with network.name_scope():
        network.add(
            gluon.nn.Conv2D(channels=20, kernel_size=5, activation="relu"),
            gluon.nn.MaxPool2D(pool_size=2, strides=2),
            gluon.nn.Conv2D(channels=50, kernel_size=5, activation="relu"),
            gluon.nn.MaxPool2D(pool_size=2, strides=2),
            gluon.nn.Flatten(),
            gluon.nn.Dense(512, activation="relu"),
            gluon.nn.Dense(10),
        )
    return network
91+
92+
93+
# Function to evaluate accuracy for a model
94+
def evaluate(model, data_iter, context):
    """Compute the accuracy of ``model`` over every batch of ``data_iter``.

    Args:
        model: Callable network applied to each batch of inputs.
        data_iter: Data iterator; it is reset before evaluation starts.
        context: Device context the batch data and labels are copied to.

    Returns:
        The result of ``mx.metric.Accuracy().get()`` after all batches have
        been accumulated (unpacked by the caller as ``name, value``).
    """
    accuracy = mx.metric.Accuracy()
    data_iter.reset()
    for batch in data_iter:
        inputs = batch.data[0].as_in_context(context)
        labels = batch.label[0].as_in_context(context)
        # Cast inputs to the configured dtype before the forward pass.
        predictions = model(inputs.astype(args.dtype, copy=False))
        accuracy.update([labels], [predictions])

    return accuracy.get()
104+
105+
106+
# Horovod must be initialised before rank/size are queried below.
hvd.init()

# Horovod: pin this process to the device indexed by its local rank.
if args.no_cuda:
    context = mx.cpu(hvd.local_rank())
else:
    context = mx.gpu(hvd.local_rank())
num_workers = hvd.size()

# Training and validation iterators for this worker.
train_data, val_data = get_mnist_iterator(hvd.rank())

# Network: build, cast to the requested dtype, then hybridize.
model = conv_nets()
model.cast(args.dtype)
model.hybridize()

# SGD optimizer; the learning rate is multiplied by the number of workers.
optimizer_params = dict(momentum=args.momentum, learning_rate=args.lr * hvd.size())
opt = mx.optimizer.create("sgd", **optimizer_params)
# opt = ts.TornasoleOptimizer(opt)

# Parameter initialisation on the selected device.
initializer = mx.init.Xavier(rnd_type="gaussian", factor_type="in", magnitude=2)
model.initialize(initializer, ctx=context)

# Horovod: broadcast the parameters from rank 0.
params = model.collect_params()
if params is not None:
    hvd.broadcast_parameters(params, root_rank=0)

# Horovod: DistributedTrainer, a subclass of gluon.Trainer.
trainer = hvd.DistributedTrainer(params, opt)

# Loss function and running training metric.
loss_fn = gluon.loss.SoftmaxCrossEntropyLoss()
metric = mx.metric.Accuracy()
141+
142+
143+
def create_tornasole_hook():
    """Build the TornasoleHook used to record tensors for this run.

    Returns:
        A ``TornasoleHook`` that writes to ``args.output_uri`` and includes
        the "weights", "gradients" and "biases" collections.
    """
    # save_interval=1 presumably requests a save at every step rather than at
    # a fixed list of steps — confirm against the SaveConfig documentation.
    save_config = SaveConfig(save_interval=1)

    # Create a hook that logs weights, biases and gradients while training the model.
    ts_hook = TornasoleHook(
        out_dir=args.output_uri,
        save_config=save_config,
        include_collections=["weights", "gradients", "biases"],
    )
    return ts_hook
155+
156+
157+
# Train model
# Create and register the Tornasole hook once, before the epoch loop. Building
# a new hook at the start of every epoch would register another hook on the
# same model each time (assumes register_hook does not replace an earlier
# registration — confirm) and each new hook would write to the same out_dir.
hook = create_tornasole_hook()
hook.register_hook(model)

for epoch in range(args.epochs):
    tic = time.time()
    train_data.reset()
    metric.reset()

    # Switch back to TRAIN at the start of every epoch; the hook is left in
    # EVAL mode at the end of the previous one.
    hook.set_mode(modes.TRAIN)

    # Stays 0 if this worker's partition yields no batches, so the speed
    # computation below never hits an unbound name.
    nbatch = 0
    for nbatch, batch in enumerate(train_data, start=1):
        data = batch.data[0].as_in_context(context)
        label = batch.label[0].as_in_context(context)
        with autograd.record():
            output = model(data.astype(args.dtype, copy=False))
            loss = loss_fn(output, label)
        loss.backward()
        trainer.step(args.batch_size)
        metric.update([label], [output])

        if nbatch % 100 == 0:
            name, acc = metric.get()
            logging.info("[Epoch %d Batch %d] Training: %s=%f", epoch, nbatch, name, acc)

    # Only rank 0 reports throughput; the sample count is scaled by the
    # number of workers.
    if hvd.rank() == 0:
        elapsed = time.time() - tic
        speed = nbatch * args.batch_size * hvd.size() / elapsed
        logging.info("Epoch[%d]\tSpeed=%.2f samples/s\tTime cost=%f", epoch, speed, elapsed)

    # Evaluate model accuracy
    hook.set_mode(modes.EVAL)
    _, train_acc = metric.get()
    name, val_acc = evaluate(model, val_data, context)
    if hvd.rank() == 0:
        logging.info(
            "Epoch[%d]\tTrain: %s=%f\tValidation: %s=%f", epoch, name, train_acc, name, val_acc
        )

    # After the final epoch, rank 0 checks the validation accuracy threshold.
    if hvd.rank() == 0 and epoch == args.epochs - 1:
        # Single-line message: a backslash continuation inside the literal
        # would embed the continuation line's indentation in the text.
        assert val_acc > 0.96, "Achieved accuracy (%f) is lower than expected (0.96)" % val_acc

tornasole/mxnet/hook.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,9 +59,23 @@ def __init__(
5959
set_hook(self)
6060

6161
def get_worker_name(self):
    """Return the name that identifies this worker.

    Returns:
        ``"worker_<rank>"`` built from the Horovod rank when ``horovod.mxnet``
        can be imported and ``hvd.size()`` is non-zero; otherwise
        ``CONFIG_DEFAULT_WORKER_NAME``.
    """
    try:
        import horovod.mxnet as hvd

        if hvd.size():
            return f"worker_{hvd.rank()}"
    except (ImportError, ValueError):
        # ImportError also covers ModuleNotFoundError (its subclass), i.e.
        # Horovod is not installed. ValueError is presumably what Horovod
        # raises when it has not been initialised — confirm. Either way this
        # is a deliberate best-effort fallback to the default name.
        pass
    return CONFIG_DEFAULT_WORKER_NAME
6370

6471
def get_num_workers(self):
    """Return the number of workers taking part in the job.

    Returns:
        The value of ``hvd.size()`` when ``horovod.mxnet`` can be imported and
        that value is non-zero; otherwise ``1``.
    """
    try:
        import horovod.mxnet as hvd

        # Query the size once and reuse it for both the check and the result.
        size = hvd.size()
        if size:
            return size
    except (ImportError, ValueError):
        # ImportError also covers ModuleNotFoundError (its subclass), i.e.
        # Horovod is not installed. ValueError is presumably what Horovod
        # raises when it has not been initialised — confirm. Either way this
        # is a deliberate best-effort fallback to a single worker.
        pass
    return 1
6680

6781
@classmethod

0 commit comments

Comments
 (0)