Skip to content

Ensure MXNet notebooks run in distributed mode. #191

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Mar 12, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 18 additions & 6 deletions sagemaker-python-sdk/mxnet_gluon_cifar10/cifar10.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,17 @@ def train(current_host, hosts, num_cpus, num_gpus, channel_input_dirs, model_dir
# load training and validation data
# we use the gluon.data.vision.CIFAR10 class because of its built in pre-processing logic,
# but point it at the location where SageMaker placed the data files, so it doesn't download them again.

part_index = 0
for i, host in enumerate(hosts):
if host == current_host:
part_index = i
break


data_dir = channel_input_dirs['training']
train_data = get_train_data(num_cpus, data_dir, batch_size, (3, 32, 32))
train_data = get_train_data(num_cpus, data_dir, batch_size, (3, 32, 32),
num_parts=len(hosts), part_index=part_index)
test_data = get_test_data(num_cpus, data_dir, batch_size, (3, 32, 32))

# Collect all parameters from net and its children, then initialize them.
Expand Down Expand Up @@ -104,23 +113,26 @@ def save(net, model_dir):
os.rename(os.path.join(model_dir, best), os.path.join(model_dir, 'model.params'))


def get_data(path, augment, num_cpus, batch_size, data_shape, resize=-1, num_parts=1, part_index=0):
    """Build an ImageRecordIter over the RecordIO file at *path*.

    ``num_parts``/``part_index`` shard the dataset so that, in distributed
    training, each host reads only its own 1/num_parts slice; the defaults
    (1, 0) read the whole file.
    """
    iterator_args = dict(
        path_imgrec=path,
        resize=resize,
        data_shape=data_shape,
        batch_size=batch_size,
        # Random crop/mirror only when augmenting (i.e. for training data).
        rand_crop=augment,
        rand_mirror=augment,
        preprocess_threads=num_cpus,
        num_parts=num_parts,
        part_index=part_index,
    )
    return mx.io.ImageRecordIter(**iterator_args)


def get_test_data(num_cpus, data_dir, batch_size, data_shape, resize=-1):
    """Iterator over the full test set — never sharded, every host evaluates all of it."""
    test_path = os.path.join(data_dir, "test.rec")
    return get_data(test_path, False, num_cpus, batch_size, data_shape, resize, 1, 0)


def get_train_data(num_cpus, data_dir, batch_size, data_shape, resize=-1, num_parts=1, part_index=0):
    """Iterator over the training set, optionally sharded for distributed training."""
    train_path = os.path.join(data_dir, "train.rec")
    return get_data(train_path, True, num_cpus, batch_size, data_shape, resize,
                    num_parts, part_index)


def test(ctx, net, test_data):
Expand Down
15 changes: 14 additions & 1 deletion sagemaker-python-sdk/mxnet_gluon_mnist/mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
# ------------------------------------------------------------ #


def train(channel_input_dirs, hyperparameters, hosts, num_gpus, **kwargs):
def train(current_host, channel_input_dirs, hyperparameters, hosts, num_gpus):
# SageMaker passes num_cpus, num_gpus and other args we can use to tailor training to
# the current container environment, but here we just use simple cpu context.
ctx = mx.cpu()
Expand Down Expand Up @@ -53,6 +53,19 @@ def train(channel_input_dirs, hyperparameters, hosts, num_gpus, **kwargs):
metric = mx.metric.Accuracy()
loss = gluon.loss.SoftmaxCrossEntropyLoss()

# Shard the training data in case we are doing distributed training. As an alternative to splitting
# in memory, the data could be pre-split in S3 and read with ShardedByS3Key to do distributed training.
if len(hosts) > 1:
train_data = [x for x in train_data]
shard_size = len(train_data) // len(hosts)
for i, host in enumerate(hosts):
if host == current_host:
start = shard_size * i
end = start + shard_size
break

train_data = train_data[start:end]

for epoch in range(epochs):
# reset data iterator and metric at beginning of epoch.
metric.reset()
Expand Down
11 changes: 10 additions & 1 deletion sagemaker-python-sdk/mxnet_gluon_sentiment/sentiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,16 @@ def train(current_host, hosts, num_cpus, num_gpus, channel_input_dirs, model_dir
train_sentences = [[vocab.get(token, 1) for token in line if len(line)>0] for line in train_sentences]
val_sentences = [[vocab.get(token, 1) for token in line if len(line)>0] for line in val_sentences]

train_iterator = BucketSentenceIter(train_sentences, train_labels, batch_size)
# As an alternative to splitting in memory, the data could be pre-split in S3 and read with
# ShardedByS3Key to do parallel training.
shard_size = len(train_sentences) // len(hosts)
for i, host in enumerate(hosts):
if host == current_host:
start = shard_size * i
end = start + shard_size
break

train_iterator = BucketSentenceIter(train_sentences[start:end], train_labels[start:end], batch_size)
val_iterator = BucketSentenceIter(val_sentences, val_labels, batch_size)

# define the network
Expand Down
14 changes: 12 additions & 2 deletions sagemaker-python-sdk/mxnet_mnist/mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,21 @@ def build_graph():
return mx.sym.SoftmaxOutput(data=fc3, name='softmax')


def train(channel_input_dirs, hyperparameters, hosts, num_cpus, num_gpus, **kwargs):
def train(current_host, channel_input_dirs, hyperparameters, hosts, num_cpus, num_gpus):
(train_labels, train_images) = load_data(os.path.join(channel_input_dirs['train']))
(test_labels, test_images) = load_data(os.path.join(channel_input_dirs['test']))

# As an alternative to splitting in memory, the data could be pre-split in S3 and read with
# ShardedByS3Key to do parallel training.
shard_size = len(train_images) // len(hosts)
for i, host in enumerate(hosts):
if host == current_host:
start = shard_size * i
end = start + shard_size
break

batch_size = 100
train_iter = mx.io.NDArrayIter(train_images, train_labels, batch_size, shuffle=True)
train_iter = mx.io.NDArrayIter(train_images[start:end], train_labels[start:end], batch_size, shuffle=True)
val_iter = mx.io.NDArrayIter(test_images, test_labels, batch_size)
logging.getLogger().setLevel(logging.DEBUG)
kvstore = 'local' if len(hosts) == 1 else 'dist_sync'
Expand Down