
Commit 54b3fec

Add notebooks for Batch Transform (aws#326)

Three notebooks (MXNet, TensorFlow, and bring-your-own container) demonstrating the new Batch Transform functionality.

1 parent: b4a679f
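
The notebooks drive Batch Transform through the SageMaker Python SDK. As a rough, illustrative sketch of the pattern (the estimator name, S3 URIs, instance type, and content type below are placeholders, not taken from this commit), a trained estimator is turned into a transformer and pointed at an S3 prefix of input files:

    # Hypothetical sketch: run a Batch Transform job from a trained estimator.
    # All names and paths are placeholders, not values from this commit.
    transformer = mnist_estimator.transformer(instance_count=1,
                                              instance_type='ml.m4.xlarge',
                                              output_path='s3://my-bucket/batch-output')

    # Score every object under the input prefix; results are written to output_path.
    transformer.transform('s3://my-bucket/batch-input',
                          content_type='text/csv',
                          split_type='Line')
    transformer.wait()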

File tree

9 files changed: +2044 −0 lines


batch_transform/mxnet_mnist/mnist.py

Lines changed: 73 additions & 0 deletions
import logging

import gzip
import mxnet as mx
import numpy as np
import os
import struct


def load_data(path):
    with gzip.open(find_file(path, "labels.gz")) as flbl:
        struct.unpack(">II", flbl.read(8))
        labels = np.fromstring(flbl.read(), dtype=np.int8)
    with gzip.open(find_file(path, "images.gz")) as fimg:
        _, _, rows, cols = struct.unpack(">IIII", fimg.read(16))
        images = np.fromstring(fimg.read(), dtype=np.uint8).reshape(len(labels), rows, cols)
        images = images.reshape(images.shape[0], 1, 28, 28).astype(np.float32) / 255
    return labels, images


def find_file(root_path, file_name):
    for root, dirs, files in os.walk(root_path):
        if file_name in files:
            return os.path.join(root, file_name)


def build_graph():
    data = mx.sym.var('data')
    data = mx.sym.flatten(data=data)
    fc1 = mx.sym.FullyConnected(data=data, num_hidden=128)
    act1 = mx.sym.Activation(data=fc1, act_type="relu")
    fc2 = mx.sym.FullyConnected(data=act1, num_hidden=64)
    act2 = mx.sym.Activation(data=fc2, act_type="relu")
    fc3 = mx.sym.FullyConnected(data=act2, num_hidden=10)
    return mx.sym.SoftmaxOutput(data=fc3, name='softmax')


def train(current_host, channel_input_dirs, hyperparameters, hosts, num_cpus, num_gpus):
    (train_labels, train_images) = load_data(os.path.join(channel_input_dirs['train']))
    (test_labels, test_images) = load_data(os.path.join(channel_input_dirs['test']))

    # Alternatively to splitting in memory, the data could be pre-split in S3 and use ShardedByS3Key
    # to do parallel training.
    shard_size = len(train_images) // len(hosts)
    for i, host in enumerate(hosts):
        if host == current_host:
            start = shard_size * i
            end = start + shard_size
            break

    batch_size = 100
    train_iter = mx.io.NDArrayIter(train_images[start:end], train_labels[start:end], batch_size, shuffle=True)
    val_iter = mx.io.NDArrayIter(test_images, test_labels, batch_size)
    logging.getLogger().setLevel(logging.DEBUG)
    kvstore = 'local' if len(hosts) == 1 else 'dist_sync'
    mlp_model = mx.mod.Module(
        symbol=build_graph(),
        context=get_train_context(num_cpus, num_gpus))
    mlp_model.fit(train_iter,
                  eval_data=val_iter,
                  kvstore=kvstore,
                  optimizer='sgd',
                  optimizer_params={'learning_rate': float(hyperparameters.get("learning_rate", 0.1))},
                  eval_metric='acc',
                  batch_end_callback=mx.callback.Speedometer(batch_size, 100),
                  num_epoch=25)
    return mlp_model


def get_train_context(num_cpus, num_gpus):
    if num_gpus > 0:
        return mx.gpu()
    return mx.cpu()
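
In the SageMaker MXNet container, the train function above is invoked with arguments the platform supplies: channel_input_dirs maps each channel name to the local directory where its S3 data was downloaded, hosts and current_host describe the training cluster, and hyperparameters comes from the estimator. A minimal, hypothetical launch of this script with the SageMaker Python SDK (the role, S3 paths, and instance settings are placeholders, not values from this commit) could look like:

    from sagemaker.mxnet import MXNet

    # Hypothetical launch of mnist.py; role, S3 paths, and instance settings are placeholders.
    mnist_estimator = MXNet(entry_point='mnist.py',
                            role='SageMakerRole',
                            train_instance_count=2,    # two hosts, so train() picks the 'dist_sync' kvstore
                            train_instance_type='ml.m4.xlarge',
                            hyperparameters={'learning_rate': 0.1})

    # The channel names used here become the keys of channel_input_dirs inside train().
    mnist_estimator.fit({'train': 's3://my-bucket/mnist/train',
                         'test': 's3://my-bucket/mnist/test'})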
