Skip to content

Add Mxnet MNIST example #1888

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Mar 8, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions frameworks/mxnet/code/config.json
76 changes: 76 additions & 0 deletions frameworks/mxnet/code/inference.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
# coding=utf-8
# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import logging

from mxnet import gluon
import mxnet as mx

import numpy as np
import json
import os

logging.basicConfig(level=logging.DEBUG)

def model_fn(model_dir):
"""Load the gluon model. Called once when hosting service starts.

:param: model_dir The directory where model files are stored.
:return: a model (in this case a Gluon network)
"""
net = gluon.SymbolBlock.imports(
symbol_file=os.path.join(model_dir, 'model-symbol.json'),
input_names=['data'],
param_file=os.path.join(model_dir, 'model-0000.params'))
return net

def transform_fn(net, data, input_content_type, output_content_type):
assert input_content_type=='application/json'
assert output_content_type=='application/json'

# parsed should be a 1d array of length 728
parsed = json.loads(data)
parsed = parsed['inputs']

# convert to numpy array
arr = np.array(parsed).reshape(-1, 1, 28, 28)

# convert to mxnet ndarray
nda = mx.nd.array(arr)

output = net(nda)

prediction = mx.nd.argmax(output, axis=1)
response_body = json.dumps(prediction.asnumpy().tolist())

return response_body, output_content_type


if __name__ == '__main__':
model_dir = '/home/ubuntu/models/mxnet-gluon-mnist'
net = model_fn(model_dir)

import json
import random
data = {'inputs': [random.random() for _ in range(784)]}
data = json.dumps(data)

content_type = 'application/json'
a, b = transform_fn(net, data, content_type, content_type)
print(a, b)


105 changes: 105 additions & 0 deletions frameworks/mxnet/code/test_inference.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
# coding=utf-8
# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from inference import transform_fn, model_fn
import os
import json
import shutil
import boto3
import botocore
import tarfile
import numpy as np
import sagemaker

def fetch_model(model_data):
""" Untar the model.tar.gz object either from local file system
or a S3 location

Args:
model_data (str): either a path to local file system starts with
file:/// that points to the `model.tar.gz` file or an S3 link
starts with s3:// that points to the `model.tar.gz` file

Returns:
model_dir (str): the directory that contains the uncompress model
checkpoint files
"""

model_dir = "/tmp/model"
if not os.path.exists(model_dir):
os.makedirs(model_dir)

if model_data.startswith("file"):
_check_model(model_data)
shutil.copy2(os.path.join(model_dir, "model.tar.gz"),
os.path.join(model_dir, "model.tar.gz"))
elif model_data.startswith("s3"):
# get bucket name and object key
bucket_name = model_data.split("/")[2]
key = "/".join(model_data.split("/")[3:])

s3 = boto3.resource("s3")
try:
s3.Bucket(bucket_name).download_file(
key, os.path.join(model_dir, 'model.tar.gz'))
except botocore.exceptions.ClientError as e:
if e.response['Error']['Code'] == '404':
print("the object does not exist.")
else:
raise

# untar the model
tar = tarfile.open(os.path.join(model_dir, 'model.tar.gz'))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Recommendation generated by Amazon CodeGuru Reviewer. Leave feedback on this recommendation by replying to the comment or by reacting to the comment using emoji.

Analysis of this code determined that this line of code contains a resource that might not have closed properly. Although there is an explicit call to close(), some execution paths do not reach it, such as in case of an exception. A resource leak can slow down or crash your system. Programs are strongly recommended to use the built in with keyword to open a resource or to use a try-finally block to open and close resources explicitly. The contextlib module provides helpful utilities for using the with statement.

Learn more

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice

tar.extractall(model_dir)
tar.close()

return model_dir


def test(model_dir):

# decompress the model.tar.gz file
# model_dir = fetch_model(model_data)

# load the model
net = model_fn(model_dir)

# simulate some input data to test transform_fn

data = {
"inputs": np.random.rand(16, 1, 28, 28).tolist()
}

# encode numpy array to binary stream
serializer = sagemaker.serializers.JSONSerializer()

jstr = serializer.serialize(data)
jstr = json.dumps(data)

# "send" the bin_stream to the endpoint for inference
# inference container calls transform_fn to make an inference
# and get the response body for the caller

res, content_type = transform_fn(net, jstr, "application/json",
"application/json")
print(res)
return


if __name__ == '__main__':
model_dir='/tmp/ckpt'
test(model_dir)


75 changes: 75 additions & 0 deletions frameworks/mxnet/code/test_train.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
# coding=utf-8
# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from train import train, parse_args

import sys
import os
import boto3
import json

dirname = os.path.dirname(os.path.abspath(__file__))
with open(os.path.join(dirname, 'config.json'), "r") as f:
CONFIG = json.load(f)

def download_from_s3(data_dir='/tmp/data', train=True):
"""Download MNIST dataset and convert it to numpy array

Args:
data_dir (str): directory to save the data
train (bool): download training set

Returns:
tuple of images and labels as numpy arrays
"""

if not os.path.exists(data_dir):
os.makedirs(data_dir)


if train:
images_file = "train-images-idx3-ubyte.gz"
labels_file = "train-labels-idx1-ubyte.gz"
else:
images_file = "t10k-images-idx3-ubyte.gz"
labels_file = "t10k-labels-idx1-ubyte.gz"

# download objects
s3 = boto3.client('s3')
bucket = CONFIG["public_bucket"]
for obj in [images_file, labels_file]:
key = os.path.join("datasets/image/MNIST", obj)
dest = os.path.join(data_dir, obj)
if not os.path.exists(dest):
s3.download_file(bucket, key, dest)
return


class Env:
def __init__(self):
# simulate container env
os.environ["SM_MODEL_DIR"] = "/tmp/model"
os.environ["SM_CHANNEL_TRAINING"]="/tmp/data"
os.environ["SM_CHANNEL_TESTING"]="/tmp/data"
os.environ["SM_HOSTS"] = '["algo-1"]'
os.environ["SM_CURRENT_HOST"]="algo-1"
os.environ["SM_NUM_GPUS"] = "0"


if __name__=='__main__':
Env()
args = parse_args()
train(args)

Loading