-
Notifications
You must be signed in to change notification settings - Fork 6.9k
Add Mxnet MNIST example #1888
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add Mxnet MNIST example #1888
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
../../../config.json |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
# coding=utf-8 | ||
# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
|
||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from __future__ import print_function | ||
|
||
import logging | ||
|
||
from mxnet import gluon | ||
import mxnet as mx | ||
|
||
import numpy as np | ||
import json | ||
import os | ||
|
||
logging.basicConfig(level=logging.DEBUG) | ||
|
||
def model_fn(model_dir): | ||
"""Load the gluon model. Called once when hosting service starts. | ||
|
||
:param: model_dir The directory where model files are stored. | ||
:return: a model (in this case a Gluon network) | ||
""" | ||
net = gluon.SymbolBlock.imports( | ||
symbol_file=os.path.join(model_dir, 'model-symbol.json'), | ||
input_names=['data'], | ||
param_file=os.path.join(model_dir, 'model-0000.params')) | ||
return net | ||
|
||
def transform_fn(net, data, input_content_type, output_content_type): | ||
assert input_content_type=='application/json' | ||
assert output_content_type=='application/json' | ||
|
||
# parsed should be a 1d array of length 728 | ||
parsed = json.loads(data) | ||
parsed = parsed['inputs'] | ||
|
||
# convert to numpy array | ||
arr = np.array(parsed).reshape(-1, 1, 28, 28) | ||
|
||
# convert to mxnet ndarray | ||
nda = mx.nd.array(arr) | ||
|
||
output = net(nda) | ||
|
||
prediction = mx.nd.argmax(output, axis=1) | ||
response_body = json.dumps(prediction.asnumpy().tolist()) | ||
|
||
return response_body, output_content_type | ||
|
||
|
||
if __name__ == '__main__': | ||
model_dir = '/home/ubuntu/models/mxnet-gluon-mnist' | ||
net = model_fn(model_dir) | ||
|
||
import json | ||
import random | ||
data = {'inputs': [random.random() for _ in range(784)]} | ||
data = json.dumps(data) | ||
|
||
content_type = 'application/json' | ||
a, b = transform_fn(net, data, content_type, content_type) | ||
print(a, b) | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,105 @@ | ||
# coding=utf-8 | ||
# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
|
||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from inference import transform_fn, model_fn | ||
hongshanli23 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
import os | ||
import json | ||
import shutil | ||
import boto3 | ||
import botocore | ||
import tarfile | ||
import numpy as np | ||
import sagemaker | ||
|
||
def fetch_model(model_data): | ||
""" Untar the model.tar.gz object either from local file system | ||
or a S3 location | ||
|
||
Args: | ||
model_data (str): either a path to local file system starts with | ||
file:/// that points to the `model.tar.gz` file or an S3 link | ||
starts with s3:// that points to the `model.tar.gz` file | ||
|
||
Returns: | ||
model_dir (str): the directory that contains the uncompress model | ||
checkpoint files | ||
""" | ||
|
||
model_dir = "/tmp/model" | ||
if not os.path.exists(model_dir): | ||
os.makedirs(model_dir) | ||
|
||
if model_data.startswith("file"): | ||
_check_model(model_data) | ||
shutil.copy2(os.path.join(model_dir, "model.tar.gz"), | ||
os.path.join(model_dir, "model.tar.gz")) | ||
elif model_data.startswith("s3"): | ||
# get bucket name and object key | ||
bucket_name = model_data.split("/")[2] | ||
key = "/".join(model_data.split("/")[3:]) | ||
|
||
s3 = boto3.resource("s3") | ||
try: | ||
s3.Bucket(bucket_name).download_file( | ||
key, os.path.join(model_dir, 'model.tar.gz')) | ||
except botocore.exceptions.ClientError as e: | ||
if e.response['Error']['Code'] == '404': | ||
print("the object does not exist.") | ||
else: | ||
raise | ||
|
||
# untar the model | ||
tar = tarfile.open(os.path.join(model_dir, 'model.tar.gz')) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Recommendation generated by Amazon CodeGuru Reviewer. Leave feedback on this recommendation by replying to the comment or by reacting to the comment using emoji. Analysis of this code determined that this line of code contains a resource that might not have closed properly. Although there is an explicit call to There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nice |
||
tar.extractall(model_dir) | ||
tar.close() | ||
|
||
return model_dir | ||
|
||
|
||
def test(model_dir): | ||
|
||
# decompress the model.tar.gz file | ||
# model_dir = fetch_model(model_data) | ||
|
||
# load the model | ||
net = model_fn(model_dir) | ||
|
||
# simulate some input data to test transform_fn | ||
|
||
data = { | ||
"inputs": np.random.rand(16, 1, 28, 28).tolist() | ||
} | ||
|
||
# encode numpy array to binary stream | ||
serializer = sagemaker.serializers.JSONSerializer() | ||
|
||
jstr = serializer.serialize(data) | ||
jstr = json.dumps(data) | ||
|
||
# "send" the bin_stream to the endpoint for inference | ||
# inference container calls transform_fn to make an inference | ||
# and get the response body for the caller | ||
|
||
res, content_type = transform_fn(net, jstr, "application/json", | ||
"application/json") | ||
print(res) | ||
return | ||
|
||
|
||
if __name__ == '__main__': | ||
model_dir='/tmp/ckpt' | ||
test(model_dir) | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
# coding=utf-8 | ||
# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
|
||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from train import train, parse_args | ||
hongshanli23 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
import sys | ||
import os | ||
import boto3 | ||
import json | ||
|
||
dirname = os.path.dirname(os.path.abspath(__file__)) | ||
with open(os.path.join(dirname, 'config.json'), "r") as f: | ||
CONFIG = json.load(f) | ||
|
||
def download_from_s3(data_dir='/tmp/data', train=True): | ||
"""Download MNIST dataset and convert it to numpy array | ||
|
||
Args: | ||
data_dir (str): directory to save the data | ||
train (bool): download training set | ||
|
||
Returns: | ||
tuple of images and labels as numpy arrays | ||
""" | ||
|
||
if not os.path.exists(data_dir): | ||
os.makedirs(data_dir) | ||
|
||
|
||
if train: | ||
images_file = "train-images-idx3-ubyte.gz" | ||
labels_file = "train-labels-idx1-ubyte.gz" | ||
else: | ||
images_file = "t10k-images-idx3-ubyte.gz" | ||
labels_file = "t10k-labels-idx1-ubyte.gz" | ||
|
||
# download objects | ||
s3 = boto3.client('s3') | ||
bucket = CONFIG["public_bucket"] | ||
for obj in [images_file, labels_file]: | ||
key = os.path.join("datasets/image/MNIST", obj) | ||
dest = os.path.join(data_dir, obj) | ||
if not os.path.exists(dest): | ||
s3.download_file(bucket, key, dest) | ||
return | ||
|
||
|
||
class Env: | ||
def __init__(self): | ||
# simulate container env | ||
os.environ["SM_MODEL_DIR"] = "/tmp/model" | ||
os.environ["SM_CHANNEL_TRAINING"]="/tmp/data" | ||
os.environ["SM_CHANNEL_TESTING"]="/tmp/data" | ||
os.environ["SM_HOSTS"] = '["algo-1"]' | ||
os.environ["SM_CURRENT_HOST"]="algo-1" | ||
os.environ["SM_NUM_GPUS"] = "0" | ||
|
||
|
||
if __name__=='__main__': | ||
Env() | ||
args = parse_args() | ||
train(args) | ||
|
Uh oh!
There was an error while loading. Please reload this page.