Skip to content

Commit efb533e

Browse files
author
Hongshan Li
committed
Merge branch 'master' of github.com:aws/amazon-sagemaker-examples
2 parents 1f328b9 + 8c2d986 commit efb533e

File tree

20 files changed

+1475
-109
lines changed

20 files changed

+1475
-109
lines changed

frameworks/mxnet/code/config.json

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
../../../config.json

frameworks/mxnet/code/inference.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
# coding=utf-8
2+
# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
3+
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
from __future__ import print_function
17+
18+
import logging
19+
20+
from mxnet import gluon
21+
import mxnet as mx
22+
23+
import numpy as np
24+
import json
25+
import os
26+
27+
logging.basicConfig(level=logging.DEBUG)
28+
29+
def model_fn(model_dir):
30+
"""Load the gluon model. Called once when hosting service starts.
31+
32+
:param: model_dir The directory where model files are stored.
33+
:return: a model (in this case a Gluon network)
34+
"""
35+
net = gluon.SymbolBlock.imports(
36+
symbol_file=os.path.join(model_dir, 'model-symbol.json'),
37+
input_names=['data'],
38+
param_file=os.path.join(model_dir, 'model-0000.params'))
39+
return net
40+
41+
def transform_fn(net, data, input_content_type, output_content_type):
42+
assert input_content_type=='application/json'
43+
assert output_content_type=='application/json'
44+
45+
# parsed should be a 1d array of length 728
46+
parsed = json.loads(data)
47+
parsed = parsed['inputs']
48+
49+
# convert to numpy array
50+
arr = np.array(parsed).reshape(-1, 1, 28, 28)
51+
52+
# convert to mxnet ndarray
53+
nda = mx.nd.array(arr)
54+
55+
output = net(nda)
56+
57+
prediction = mx.nd.argmax(output, axis=1)
58+
response_body = json.dumps(prediction.asnumpy().tolist())
59+
60+
return response_body, output_content_type
61+
62+
63+
if __name__ == '__main__':
64+
model_dir = '/home/ubuntu/models/mxnet-gluon-mnist'
65+
net = model_fn(model_dir)
66+
67+
import json
68+
import random
69+
data = {'inputs': [random.random() for _ in range(784)]}
70+
data = json.dumps(data)
71+
72+
content_type = 'application/json'
73+
a, b = transform_fn(net, data, content_type, content_type)
74+
print(a, b)
75+
76+
Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
# coding=utf-8
2+
# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
3+
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
from inference import transform_fn, model_fn
17+
import os
18+
import json
19+
import shutil
20+
import boto3
21+
import botocore
22+
import tarfile
23+
import numpy as np
24+
import sagemaker
25+
26+
def fetch_model(model_data):
27+
""" Untar the model.tar.gz object either from local file system
28+
or a S3 location
29+
30+
Args:
31+
model_data (str): either a path to local file system starts with
32+
file:/// that points to the `model.tar.gz` file or an S3 link
33+
starts with s3:// that points to the `model.tar.gz` file
34+
35+
Returns:
36+
model_dir (str): the directory that contains the uncompress model
37+
checkpoint files
38+
"""
39+
40+
model_dir = "/tmp/model"
41+
if not os.path.exists(model_dir):
42+
os.makedirs(model_dir)
43+
44+
if model_data.startswith("file"):
45+
_check_model(model_data)
46+
shutil.copy2(os.path.join(model_dir, "model.tar.gz"),
47+
os.path.join(model_dir, "model.tar.gz"))
48+
elif model_data.startswith("s3"):
49+
# get bucket name and object key
50+
bucket_name = model_data.split("/")[2]
51+
key = "/".join(model_data.split("/")[3:])
52+
53+
s3 = boto3.resource("s3")
54+
try:
55+
s3.Bucket(bucket_name).download_file(
56+
key, os.path.join(model_dir, 'model.tar.gz'))
57+
except botocore.exceptions.ClientError as e:
58+
if e.response['Error']['Code'] == '404':
59+
print("the object does not exist.")
60+
else:
61+
raise
62+
63+
# untar the model
64+
tar = tarfile.open(os.path.join(model_dir, 'model.tar.gz'))
65+
tar.extractall(model_dir)
66+
tar.close()
67+
68+
return model_dir
69+
70+
71+
def test(model_dir):
72+
73+
# decompress the model.tar.gz file
74+
# model_dir = fetch_model(model_data)
75+
76+
# load the model
77+
net = model_fn(model_dir)
78+
79+
# simulate some input data to test transform_fn
80+
81+
data = {
82+
"inputs": np.random.rand(16, 1, 28, 28).tolist()
83+
}
84+
85+
# encode numpy array to binary stream
86+
serializer = sagemaker.serializers.JSONSerializer()
87+
88+
jstr = serializer.serialize(data)
89+
jstr = json.dumps(data)
90+
91+
# "send" the bin_stream to the endpoint for inference
92+
# inference container calls transform_fn to make an inference
93+
# and get the response body for the caller
94+
95+
res, content_type = transform_fn(net, jstr, "application/json",
96+
"application/json")
97+
print(res)
98+
return
99+
100+
101+
if __name__ == '__main__':
102+
model_dir='/tmp/ckpt'
103+
test(model_dir)
104+
105+

frameworks/mxnet/code/test_train.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
# coding=utf-8
2+
# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
3+
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
from train import train, parse_args
17+
18+
import sys
19+
import os
20+
import boto3
21+
import json
22+
23+
dirname = os.path.dirname(os.path.abspath(__file__))
24+
with open(os.path.join(dirname, 'config.json'), "r") as f:
25+
CONFIG = json.load(f)
26+
27+
def download_from_s3(data_dir='/tmp/data', train=True):
28+
"""Download MNIST dataset and convert it to numpy array
29+
30+
Args:
31+
data_dir (str): directory to save the data
32+
train (bool): download training set
33+
34+
Returns:
35+
tuple of images and labels as numpy arrays
36+
"""
37+
38+
if not os.path.exists(data_dir):
39+
os.makedirs(data_dir)
40+
41+
42+
if train:
43+
images_file = "train-images-idx3-ubyte.gz"
44+
labels_file = "train-labels-idx1-ubyte.gz"
45+
else:
46+
images_file = "t10k-images-idx3-ubyte.gz"
47+
labels_file = "t10k-labels-idx1-ubyte.gz"
48+
49+
# download objects
50+
s3 = boto3.client('s3')
51+
bucket = CONFIG["public_bucket"]
52+
for obj in [images_file, labels_file]:
53+
key = os.path.join("datasets/image/MNIST", obj)
54+
dest = os.path.join(data_dir, obj)
55+
if not os.path.exists(dest):
56+
s3.download_file(bucket, key, dest)
57+
return
58+
59+
60+
class Env:
61+
def __init__(self):
62+
# simulate container env
63+
os.environ["SM_MODEL_DIR"] = "/tmp/model"
64+
os.environ["SM_CHANNEL_TRAINING"]="/tmp/data"
65+
os.environ["SM_CHANNEL_TESTING"]="/tmp/data"
66+
os.environ["SM_HOSTS"] = '["algo-1"]'
67+
os.environ["SM_CURRENT_HOST"]="algo-1"
68+
os.environ["SM_NUM_GPUS"] = "0"
69+
70+
71+
if __name__=='__main__':
72+
Env()
73+
args = parse_args()
74+
train(args)
75+

0 commit comments

Comments
 (0)