@@ -1,5 +1,3 @@
-# Copyright 2017-2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
-#
 # Licensed under the Apache License, Version 2.0 (the "License"). You
 # may not use this file except in compliance with the License. A copy of
 # the License is located at
@@ -12,13 +10,17 @@
 # language governing permissions and limitations under the License.
 from __future__ import absolute_import
 
+import argparse
+import gzip
+import json
 import logging
+import os
+import struct
 
-import gzip
 import mxnet as mx
 import numpy as np
-import os
-import struct
+
+from sagemaker_mxnet_container.training_utils import scheduler_host
 
 
 def load_data(path):
@@ -56,23 +58,70 @@ def get_train_context(num_gpus):
     return mx.cpu()
 
 
-def train(channel_input_dirs, hyperparameters, hosts, num_gpus, **kwargs):
-    (train_labels, train_images) = load_data(os.path.join(channel_input_dirs['train']))
-    (test_labels, test_images) = load_data(os.path.join(channel_input_dirs['test']))
-    batch_size = 100
-    train_iter = mx.io.NDArrayIter(train_images, train_labels, batch_size, shuffle=True)
+def train(batch_size, epochs, learning_rate, num_gpus, training_channel, testing_channel,
+          hosts, current_host, model_dir):
+    (train_labels, train_images) = load_data(training_channel)
+    (test_labels, test_images) = load_data(testing_channel)
+
+    # Data parallel training - shard the data so each host
+    # only trains on a subset of the total data.
+    shard_size = len(train_images) // len(hosts)
+    for i, host in enumerate(hosts):
+        if host == current_host:
+            start = shard_size * i
+            end = start + shard_size
+            break
+
+    train_iter = mx.io.NDArrayIter(train_images[start:end], train_labels[start:end], batch_size,
+                                   shuffle=True)
     val_iter = mx.io.NDArrayIter(test_images, test_labels, batch_size)
+
     logging.getLogger().setLevel(logging.DEBUG)
+
     kvstore = 'local' if len(hosts) == 1 else 'dist_sync'
-    mlp_model = mx.mod.Module(
-        symbol=build_graph(),
-        context=get_train_context(num_gpus))
+
+    mlp_model = mx.mod.Module(symbol=build_graph(),
+                              context=get_train_context(num_gpus))
     mlp_model.fit(train_iter,
                   eval_data=val_iter,
                   kvstore=kvstore,
                   optimizer='sgd',
-                  optimizer_params={'learning_rate': float(hyperparameters.get("learning_rate", 0.1))},
+                  optimizer_params={'learning_rate': learning_rate},
                   eval_metric='acc',
                   batch_end_callback=mx.callback.Speedometer(batch_size, 100),
-                  num_epoch=1)
-    return mlp_model
+                  num_epoch=epochs)
+
+    if len(hosts) == 1 or current_host == scheduler_host(hosts):
+        save(model_dir, mlp_model)
+
+
+def save(model_dir, model):
+    model.symbol.save(os.path.join(model_dir, 'model-symbol.json'))
+    model.save_params(os.path.join(model_dir, 'model-0000.params'))
+
+    signature = [{'name': data_desc.name, 'shape': [dim for dim in data_desc.shape]}
+                 for data_desc in model.data_shapes]
+    with open(os.path.join(model_dir, 'model-shapes.json'), 'w') as f:
+        json.dump(signature, f)
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument('--batch-size', type=int, default=100)
+    parser.add_argument('--epochs', type=int, default=10)
+    parser.add_argument('--learning-rate', type=float, default=0.1)
+
+    parser.add_argument('--model-dir', type=str, default=os.environ['SM_MODEL_DIR'])
+    parser.add_argument('--train', type=str, default=os.environ['SM_CHANNEL_TRAIN'])
+    parser.add_argument('--test', type=str, default=os.environ['SM_CHANNEL_TEST'])
+
+    parser.add_argument('--current-host', type=str, default=os.environ['SM_CURRENT_HOST'])
+    parser.add_argument('--hosts', type=list, default=json.loads(os.environ['SM_HOSTS']))
+
+    args = parser.parse_args()
+
+    num_gpus = int(os.environ['SM_NUM_GPUS'])
+
+    train(args.batch_size, args.epochs, args.learning_rate, num_gpus, args.train, args.test,
+          args.hosts, args.current_host, args.model_dir)
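To make the sharding arithmetic in train() concrete, here is a hedged trace, assuming a two-host job and the standard 60,000-image MNIST training set (both assumptions, not stated in this diff); SageMaker names hosts algo-1, algo-2, and so on:

# Hypothetical values: a two-host cluster over the full MNIST training set.
hosts = ['algo-1', 'algo-2']           # what json.loads(os.environ['SM_HOSTS']) might yield
current_host = 'algo-2'
train_images = [None] * 60000          # stand-in for the loaded image array

shard_size = len(train_images) // len(hosts)   # 30000
for i, host in enumerate(hosts):
    if host == current_host:
        start = shard_size * i                 # 30000
        end = start + shard_size               # 60000
        break

print(start, end)  # -> 30000 60000: this host trains on the second half of the data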
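The save() helper writes MXNet's standard checkpoint names, 'model-symbol.json' and 'model-0000.params', plus a 'model-shapes.json' input signature, so the artifacts can later be rebuilt with mx.mod.Module.load. A minimal loading sketch under that assumption (load_model here is an illustrative helper, not part of this change):

import json
import os

import mxnet as mx


def load_model(model_dir):
    # Module.load expects the checkpoint prefix and epoch behind
    # '<prefix>-symbol.json' and '<prefix>-0000.params', i.e. 'model' and 0 here.
    module = mx.mod.Module.load(os.path.join(model_dir, 'model'), 0)

    # model-shapes.json was written by save() above; reuse it to bind the
    # module for inference with the same input signature it was trained with.
    with open(os.path.join(model_dir, 'model-shapes.json')) as f:
        signature = json.load(f)
    data_shapes = [(s['name'], tuple(s['shape'])) for s in signature]

    module.bind(for_training=False, data_shapes=data_shapes)
    return module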
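Finally, the __main__ block reads hyperparameters from the command line and data locations from the SM_* environment variables that the SageMaker container sets; hyperparameters passed to the estimator are forwarded to the script as --key value arguments. For context, a hedged sketch of launching this entry point with the SageMaker Python SDK's MXNet estimator (the script name, role ARN, S3 paths, and framework version below are placeholders, not part of this commit):

from sagemaker.mxnet import MXNet

# Placeholder values throughout -- substitute a real role ARN and S3 paths.
estimator = MXNet(entry_point='mnist.py',        # assumed name of the script in this diff
                  role='arn:aws:iam::123456789012:role/SageMakerRole',
                  train_instance_count=2,        # more than one host, so train() picks 'dist_sync'
                  train_instance_type='ml.m4.xlarge',
                  framework_version='1.3.0',     # script-mode container version, assumed
                  hyperparameters={'batch-size': 100,      # becomes --batch-size
                                   'epochs': 10,           # becomes --epochs
                                   'learning-rate': 0.1})  # becomes --learning-rate

# Channel names map to SM_CHANNEL_TRAIN / SM_CHANNEL_TEST inside the container.
estimator.fit({'train': 's3://my-bucket/mnist/train',
               'test': 's3://my-bucket/mnist/test'})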