
Commit f34b128

update integ tests

1 parent 2008468 commit f34b128

4 files changed: +72 −22 lines

tests/conftest.py

Lines changed: 3 additions & 3 deletions
@@ -22,7 +22,6 @@
 from sagemaker.local import LocalSession
 from sagemaker.chainer.defaults import CHAINER_VERSION
 from sagemaker.pytorch.defaults import PYTORCH_VERSION
-from sagemaker.mxnet.defaults import MXNET_VERSION
 from sagemaker.tensorflow.defaults import TF_VERSION
 
 
@@ -34,7 +33,7 @@ def pytest_addoption(parser):
     parser.addoption('--sagemaker-runtime-config', action='store', default=None)
     parser.addoption('--boto-config', action='store', default=None)
     parser.addoption('--tf-full-version', action='store', default=TF_VERSION)
-    parser.addoption('--mxnet-full-version', action='store', default=MXNET_VERSION)
+    parser.addoption('--mxnet-full-version', action='store', default='1.3.0')
     parser.addoption('--chainer-full-version', action='store', default=CHAINER_VERSION)
     parser.addoption('--pytorch-full-version', action='store', default=PYTORCH_VERSION)
 
@@ -86,7 +85,8 @@ def tf_version(request):
     return request.param
 
 
-@pytest.fixture(scope='module', params=['0.12', '0.12.1', '1.0', '1.0.0', '1.1', '1.1.0', '1.2', '1.2.1'])
+@pytest.fixture(scope='module', params=['0.12', '0.12.1', '1.0', '1.0.0', '1.1', '1.1.0', '1.2',
+                                        '1.2.1', '1.3', '1.3.0'])
 def mxnet_version(request):
     return request.param
 
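Editor's note: pytest generates one test invocation per entry in params, which is how the new '1.3' and '1.3.0' values fan out across every test that takes the mxnet_version fixture. A minimal, self-contained sketch of that mechanism (the test name and assertion are illustrative, not from this repo):

import pytest


@pytest.fixture(scope='module', params=['1.2.1', '1.3', '1.3.0'])
def mxnet_version(request):
    return request.param


# pytest expands this into test_version[1.2.1], test_version[1.3], and
# test_version[1.3.0] - one run per fixture param.
def test_version(mxnet_version):
    assert mxnet_version.startswith('1.')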

tests/data/mxnet_mnist/failure_script.py

Lines changed: 3 additions & 3 deletions
@@ -12,6 +12,6 @@
 # language governing permissions and limitations under the License.
 from __future__ import absolute_import
 
-def train(**kwargs):
-    """For use with integration tests expecting failures."""
-    raise Exception('This failure is expected.')
+
+# For use with integration tests expecting failures.
+raise Exception('This failure is expected.')
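Editor's note: this change tracks script mode, which executes the file as a program instead of importing it and calling a train() function, so the expected failure now raises at module scope. A quick local way to observe the new behavior (this check is a sketch, not part of the commit):

import subprocess

# Running the script as a program now fails immediately at module scope,
# so the process exits nonzero.
result = subprocess.run(['python', 'tests/data/mxnet_mnist/failure_script.py'])
assert result.returncode != 0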

tests/data/mxnet_mnist/mnist.py

Lines changed: 65 additions & 16 deletions
@@ -1,5 +1,3 @@
-# Copyright 2017-2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
-#
 # Licensed under the Apache License, Version 2.0 (the "License"). You
 # may not use this file except in compliance with the License. A copy of
 # the License is located at
@@ -12,13 +10,17 @@
 # language governing permissions and limitations under the License.
 from __future__ import absolute_import
 
+import argparse
+import gzip
+import json
 import logging
+import os
+import struct
 
-import gzip
 import mxnet as mx
 import numpy as np
-import os
-import struct
+
+from sagemaker_mxnet_container.training_utils import scheduler_host
 
 
 def load_data(path):
@@ -56,23 +58,70 @@ def get_train_context(num_gpus):
     return mx.cpu()
 
 
-def train(channel_input_dirs, hyperparameters, hosts, num_gpus, **kwargs):
-    (train_labels, train_images) = load_data(os.path.join(channel_input_dirs['train']))
-    (test_labels, test_images) = load_data(os.path.join(channel_input_dirs['test']))
-    batch_size = 100
-    train_iter = mx.io.NDArrayIter(train_images, train_labels, batch_size, shuffle=True)
+def train(batch_size, epochs, learning_rate, num_gpus, training_channel, testing_channel,
+          hosts, current_host, model_dir):
+    (train_labels, train_images) = load_data(training_channel)
+    (test_labels, test_images) = load_data(testing_channel)
+
+    # Data parallel training - shard the data so each host
+    # only trains on a subset of the total data.
+    shard_size = len(train_images) // len(hosts)
+    for i, host in enumerate(hosts):
+        if host == current_host:
+            start = shard_size * i
+            end = start + shard_size
+            break
+
+    train_iter = mx.io.NDArrayIter(train_images[start:end], train_labels[start:end], batch_size,
+                                   shuffle=True)
     val_iter = mx.io.NDArrayIter(test_images, test_labels, batch_size)
+
     logging.getLogger().setLevel(logging.DEBUG)
+
     kvstore = 'local' if len(hosts) == 1 else 'dist_sync'
-    mlp_model = mx.mod.Module(
-        symbol=build_graph(),
-        context=get_train_context(num_gpus))
+
+    mlp_model = mx.mod.Module(symbol=build_graph(),
+                              context=get_train_context(num_gpus))
     mlp_model.fit(train_iter,
                   eval_data=val_iter,
                   kvstore=kvstore,
                   optimizer='sgd',
-                  optimizer_params={'learning_rate': float(hyperparameters.get("learning_rate", 0.1))},
+                  optimizer_params={'learning_rate': learning_rate},
                   eval_metric='acc',
                   batch_end_callback=mx.callback.Speedometer(batch_size, 100),
-                  num_epoch=1)
-    return mlp_model
+                  num_epoch=epochs)
+
+    if len(hosts) == 1 or current_host == scheduler_host(hosts):
+        save(model_dir, mlp_model)
+
+
+def save(model_dir, model):
+    model.symbol.save(os.path.join(model_dir, 'model-symbol.json'))
+    model.save_params(os.path.join(model_dir, 'model-0000.params'))
+
+    signature = [{'name': data_desc.name, 'shape': [dim for dim in data_desc.shape]}
+                 for data_desc in model.data_shapes]
+    with open(os.path.join(model_dir, 'model-shapes.json'), 'w') as f:
+        json.dump(signature, f)
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument('--batch-size', type=int, default=100)
+    parser.add_argument('--epochs', type=int, default=10)
+    parser.add_argument('--learning-rate', type=float, default=0.1)
+
+    parser.add_argument('--model-dir', type=str, default=os.environ['SM_MODEL_DIR'])
+    parser.add_argument('--train', type=str, default=os.environ['SM_CHANNEL_TRAIN'])
+    parser.add_argument('--test', type=str, default=os.environ['SM_CHANNEL_TEST'])
+
+    parser.add_argument('--current-host', type=str, default=os.environ['SM_CURRENT_HOST'])
+    parser.add_argument('--hosts', type=list, default=json.loads(os.environ['SM_HOSTS']))
+
+    args = parser.parse_args()
+
+    num_gpus = int(os.environ['SM_NUM_GPUS'])

+    train(args.batch_size, args.epochs, args.learning_rate, num_gpus, args.train, args.test,
+          args.hosts, args.current_host, args.model_dir)
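Editor's note: the rewritten mnist.py follows the script-mode contract: SageMaker supplies channel paths and cluster layout through SM_* environment variables, passes hyperparameters as command-line flags, and runs the file as __main__. A rough local simulation of that contract (all paths and values below are illustrative assumptions, not from the commit):

import json
import os
import subprocess

# Fake the SM_* variables the container would set, then invoke the script
# the same way script mode does.
env = dict(os.environ,
           SM_MODEL_DIR='/tmp/model',
           SM_CHANNEL_TRAIN='/tmp/data/train',
           SM_CHANNEL_TEST='/tmp/data/test',
           SM_CURRENT_HOST='algo-1',
           SM_HOSTS=json.dumps(['algo-1']),
           SM_NUM_GPUS='0')
subprocess.run(['python', 'mnist.py', '--epochs', '1', '--learning-rate', '0.1'],
               env=env, check=True)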

tests/integ/test_tuner.py

Lines changed: 1 addition & 0 deletions
@@ -191,6 +191,7 @@ def test_tuning_mxnet(sagemaker_session):
                       py_version=PYTHON_VERSION,
                       train_instance_count=1,
                       train_instance_type='ml.m4.xlarge',
+                      framework_version='1.2.1',
                       sagemaker_session=sagemaker_session,
                       base_job_name='tune-mxnet')
 
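Editor's note: pinning framework_version keeps the tuning integ test on a released MXNet image rather than whatever default the SDK currently ships. A sketch of the resulting estimator construction, with the arguments elided from the hunk filled in by illustrative values (the entry point, role, and py_version here are assumptions, not from the diff):

from sagemaker.mxnet import MXNet

# framework_version selects the MXNet 1.2.1 container explicitly; without
# it the SDK falls back to its current default image version.
estimator = MXNet(entry_point='mnist.py',        # illustrative script path
                  role='SageMakerRole',          # illustrative IAM role name
                  py_version='py3',              # the test passes PYTHON_VERSION here
                  train_instance_count=1,
                  train_instance_type='ml.m4.xlarge',
                  framework_version='1.2.1',
                  base_job_name='tune-mxnet')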
