
Commit 329bfcf

GaryTu1020 authored and chuyang-deng committed
change: git support testing (#861)
* add example files for git support testing
1 parent 3b26712 commit 329bfcf

File tree

17 files changed (+363, -141 lines)


.codecov.yml

Lines changed: 0 additions & 2 deletions
This file was deleted.

.coveragerc

Lines changed: 0 additions & 2 deletions
This file was deleted.

.flake8

Lines changed: 0 additions & 3 deletions
This file was deleted.

.gitignore

Lines changed: 0 additions & 28 deletions
This file was deleted.

.pylintrc

Lines changed: 0 additions & 90 deletions
This file was deleted.

.readthedocs.yml

Lines changed: 0 additions & 16 deletions
This file was deleted.

README.md

Lines changed: 2 additions & 0 deletions
@@ -0,0 +1,2 @@
# python-sdk-testing
It's a repo for testing the sagemaker Python SDK Git support
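The README spells out the commit's purpose: these files are meant to be cloned by the SageMaker Python SDK's Git support rather than shipped from a local source directory. As a rough sketch (not part of this commit, and using the v1.x-era estimator API), a training job would point at a repo like this via git_config; the repo URL, IAM role, and instance settings below are placeholders:

# Sketch only: consuming a repo like this through the SageMaker Python SDK's
# Git support (v1.x-era API). Repo URL, role, and instance settings are
# placeholders, not values taken from this commit.
from sagemaker.mxnet import MXNet

git_config = {
    'repo': 'https://github.com/<owner>/python-sdk-testing.git',  # placeholder URL
    'branch': 'master',
    'commit': '329bfcf',
}

estimator = MXNet(
    entry_point='mxnet/mnist.py',      # training script added in this commit
    dependencies=['foo/bar.py'],       # the module mnist.py imports as `bar`
    git_config=git_config,             # clone the repo instead of using local source
    role='arn:aws:iam::111122223333:role/SageMakerRole',  # placeholder role
    train_instance_count=1,
    train_instance_type='ml.m4.xlarge',
    framework_version='1.4.0',
    py_version='py3',
)

# estimator.fit({'train': 's3://<bucket>/mnist/train', 'test': 's3://<bucket>/mnist/test'})

With git_config set, entry_point and dependencies are resolved relative to the root of the cloned repo, and the dependency files are copied alongside the training script inside the container, which is what lets mnist.py get away with a bare import bar.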

alexa.py

Whitespace-only changes.

foo/bar.py

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
# this is supposed to be a dependency.

foo/bar/a-file

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
This is a file.

foo/some-file

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
This is a file.

mxnet/mnist.py

Lines changed: 126 additions & 0 deletions
@@ -0,0 +1,126 @@
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
from __future__ import absolute_import

import argparse
import gzip
import json
import logging
import os
import struct
import bar

import mxnet as mx
import numpy as np


def load_data(path):
    with gzip.open(find_file(path, "labels.gz")) as flbl:
        struct.unpack(">II", flbl.read(8))
        labels = np.fromstring(flbl.read(), dtype=np.int8)
    with gzip.open(find_file(path, "images.gz")) as fimg:
        _, _, rows, cols = struct.unpack(">IIII", fimg.read(16))
        images = np.fromstring(fimg.read(), dtype=np.uint8).reshape(len(labels), rows, cols)
        images = images.reshape(images.shape[0], 1, 28, 28).astype(np.float32) / 255
    return labels, images


def find_file(root_path, file_name):
    for root, dirs, files in os.walk(root_path):
        if file_name in files:
            return os.path.join(root, file_name)


def build_graph():
    data = mx.sym.var('data')
    data = mx.sym.flatten(data=data)
    fc1 = mx.sym.FullyConnected(data=data, num_hidden=128)
    act1 = mx.sym.Activation(data=fc1, act_type="relu")
    fc2 = mx.sym.FullyConnected(data=act1, num_hidden=64)
    act2 = mx.sym.Activation(data=fc2, act_type="relu")
    fc3 = mx.sym.FullyConnected(data=act2, num_hidden=10)
    return mx.sym.SoftmaxOutput(data=fc3, name='softmax')


def get_train_context(num_gpus):
    if num_gpus:
        return [mx.gpu(i) for i in range(num_gpus)]
    else:
        return mx.cpu()


def train(batch_size, epochs, learning_rate, num_gpus, training_channel, testing_channel,
          hosts, current_host, model_dir):
    (train_labels, train_images) = load_data(training_channel)
    (test_labels, test_images) = load_data(testing_channel)

    # Data parallel training - shard the data so each host
    # only trains on a subset of the total data.
    shard_size = len(train_images) // len(hosts)
    for i, host in enumerate(hosts):
        if host == current_host:
            start = shard_size * i
            end = start + shard_size
            break

    train_iter = mx.io.NDArrayIter(train_images[start:end], train_labels[start:end], batch_size,
                                   shuffle=True)
    val_iter = mx.io.NDArrayIter(test_images, test_labels, batch_size)

    logging.getLogger().setLevel(logging.DEBUG)

    kvstore = 'local' if len(hosts) == 1 else 'dist_sync'

    mlp_model = mx.mod.Module(symbol=build_graph(),
                              context=get_train_context(num_gpus))
    mlp_model.fit(train_iter,
                  eval_data=val_iter,
                  kvstore=kvstore,
                  optimizer='sgd',
                  optimizer_params={'learning_rate': learning_rate},
                  eval_metric='acc',
                  batch_end_callback=mx.callback.Speedometer(batch_size, 100),
                  num_epoch=epochs)

    if len(hosts) == 1 or current_host == hosts[0]:
        save(model_dir, mlp_model)


def save(model_dir, model):
    model.symbol.save(os.path.join(model_dir, 'model-symbol.json'))
    model.save_params(os.path.join(model_dir, 'model-0000.params'))

    signature = [{'name': data_desc.name, 'shape': [dim for dim in data_desc.shape]}
                 for data_desc in model.data_shapes]
    with open(os.path.join(model_dir, 'model-shapes.json'), 'w') as f:
        json.dump(signature, f)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()

    parser.add_argument('--batch-size', type=int, default=100)
    parser.add_argument('--epochs', type=int, default=10)
    parser.add_argument('--learning-rate', type=float, default=0.1)

    parser.add_argument('--model-dir', type=str, default=os.environ['SM_MODEL_DIR'])
    parser.add_argument('--train', type=str, default=os.environ['SM_CHANNEL_TRAIN'])
    parser.add_argument('--test', type=str, default=os.environ['SM_CHANNEL_TEST'])

    parser.add_argument('--current-host', type=str, default=os.environ['SM_CURRENT_HOST'])
    parser.add_argument('--hosts', type=list, default=json.loads(os.environ['SM_HOSTS']))

    args = parser.parse_args()

    num_gpus = int(os.environ['SM_NUM_GPUS'])

    train(args.batch_size, args.epochs, args.learning_rate, num_gpus, args.train, args.test,
          args.hosts, args.current_host, args.model_dir)
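The script takes its defaults from the SM_* environment variables that SageMaker injects into the training container. Outside SageMaker those variables have to be supplied by hand; a minimal local smoke-test sketch follows (paths are hypothetical, and labels.gz/images.gz must already sit under the train and test directories for load_data to find them):

# Local smoke-test sketch for mnist.py (paths and values are hypothetical).
# SageMaker normally injects these SM_* variables inside the training container.
import json
import os
import subprocess

env = dict(
    os.environ,
    SM_MODEL_DIR='/tmp/model',           # where model-symbol.json / params are written
    SM_CHANNEL_TRAIN='/tmp/data/train',   # must contain labels.gz and images.gz
    SM_CHANNEL_TEST='/tmp/data/test',     # must contain labels.gz and images.gz
    SM_CURRENT_HOST='algo-1',
    SM_HOSTS=json.dumps(['algo-1']),      # single host, so kvstore stays 'local'
    SM_NUM_GPUS='0',
    PYTHONPATH='foo',                     # so the bare `import bar` resolves to foo/bar.py
)

subprocess.run(['python', 'mxnet/mnist.py', '--epochs', '1'], env=env, check=True)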

mxnet/some_file

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
This is a file.
