
Commit ae15c9d

GaryTu1020 authored and chuyang-deng committed
change: add serving script for mxnet (#882)
* add inference script
1 parent 329bfcf commit ae15c9d

File tree

3 files changed: +156 −1 lines changed


.gitignore

Lines changed: 28 additions & 0 deletions
@@ -0,0 +1,28 @@
.idea
build
src/*.egg-info
.cache
.coverage
sagemaker_venv*
*.egg-info
.tox
**/__pycache__
**/.ipynb_checkpoints
dist/
**/tensorflow-examples.tar.gz
**/*.pyc
**.pyc
scratch*.py
.eggs
*.egg
examples/tensorflow/distributed_mnist/data
*.iml
doc/_build
doc/_static
doc/_templates
**/.DS_Store
venv/
*~
.pytest_cache/
*.swp
.docker/

README.md

Lines changed: 1 addition & 1 deletion
@@ -1,2 +1,2 @@
-# python-sdk-testing
+# test-branch-git-config
 It's a repo for testing the sagemaker Python SDK Git support
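The README describes this repository as a fixture for the SageMaker Python SDK's Git support. For context, here is a minimal sketch of how the SDK consumes such a repo through its git_config parameter; the repo URL, role ARN, and entry-point filename are illustrative assumptions, not values taken from this commit.

from sagemaker.mxnet import MXNet

# All values below are placeholders for illustration; substitute your own.
git_config = {
    "repo": "https://github.com/user/test-branch-git-config.git",  # assumed URL
    "branch": "master",
    "commit": "ae15c9d",  # pin the checkout to this commit
}

estimator = MXNet(
    entry_point="train.py",  # assumed path of a script inside the repo
    git_config=git_config,   # the SDK clones the repo and resolves entry_point within it
    role="arn:aws:iam::123456789012:role/SageMakerRole",  # placeholder role ARN
    framework_version="1.4.1",
    py_version="py3",
    instance_count=1,
    instance_type="ml.m5.xlarge",
)

With git_config set, the SDK checks out the given branch and commit before packaging the entry point, which is what this repository exists to exercise.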
New file (the MXNet serving script added by this commit)

Lines changed: 127 additions & 0 deletions
@@ -0,0 +1,127 @@
# Copyright 2017-2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
from __future__ import absolute_import

import gzip
import json
import os
import struct

import mxnet as mx
import numpy as np


# --- This example demonstrates how to extend default behavior during model hosting ---

# --- Model preparation ---
# It is possible to specify your own code to load the model; otherwise default model loading takes place.
def model_fn(path_to_model_files):
    from mxnet.io import DataDesc

    loaded_symbol = mx.symbol.load(os.path.join(path_to_model_files, "symbol"))
    created_module = mx.mod.Module(symbol=loaded_symbol)
    created_module.bind([DataDesc("data", (1, 1, 28, 28))])
    created_module.load_params(os.path.join(path_to_model_files, "params"))
    return created_module


# --- Option 1 - provide a single entry point for end-to-end prediction ---
# If this function is specified, none of the overrides described in Option 2 take effect.
# Returns the serialized data and the content type it used.
def transform_fn(model, request_data, input_content_type, requested_output_content_type):
    # For demonstration purposes we call the handlers from Option 2.
    return (
        output_fn(
            process_request_fn(model, request_data, input_content_type),
            requested_output_content_type,
        ),
        requested_output_content_type,
    )


# --- Option 2 - override the container's default input/output behavior with handlers ---
# There are two data handlers, input and output; conform to their interfaces to fit into the default execution.
def process_request_fn(model, data, input_content_type):
    if input_content_type == "text/s3_file_path":
        prediction_input = handle_s3_file_path(data)
    elif input_content_type == "application/json":
        prediction_input = handle_json_input(data)
    else:
        raise NotImplementedError(
            "This model doesn't support the requested input type: " + input_content_type
        )

    return model.predict(prediction_input)


# For this example, the S3 path points to a file in the same format as test/images.gz.
def handle_s3_file_path(path):
    import sys

    # urlparse moved between Python 2 and 3.
    if sys.version_info.major == 2:
        from urlparse import urlparse as parse_cmd
    else:
        from urllib.parse import urlparse as parse_cmd

    import boto3
    from botocore.exceptions import ClientError

    # Parse the S3 path into bucket and key.
    parsed_url = parse_cmd(path)

    # Get an S3 resource.
    s3 = boto3.resource("s3")

    # Read the file content and pass it down.
    obj = s3.Object(parsed_url.netloc, parsed_url.path.lstrip("/"))
    print("loading file: " + str(obj))

    try:
        data = obj.get()["Body"]
    except ClientError as ce:
        raise ValueError(
            "Can't download from S3 path: " + path + " : " + ce.response["Error"]["Message"]
        )

    # gzip needs a binary buffer; io.BytesIO works on both Python 2 and 3.
    from io import BytesIO

    buf = BytesIO(data.read())
    img = gzip.GzipFile(mode="rb", fileobj=buf)

    # MNIST-style header: magic number, image count, rows, cols (big-endian).
    _, _, rows, cols = struct.unpack(">IIII", img.read(16))
    images = np.frombuffer(img.read(), dtype=np.uint8).reshape(10000, rows, cols)
    images = images.reshape(images.shape[0], 1, 28, 28).astype(np.float32) / 255

    return mx.io.NDArrayIter(images, None, 1)


# For this example, the client is assumed to pass data that can be provided "directly" to the model.
def handle_json_input(data):
    nda = mx.nd.array(json.loads(data))
    return mx.io.NDArrayIter(nda, None, 1)


def output_fn(prediction_output, requested_output_content_type):
    # Output from the model is an NDArray.
    data_to_return = prediction_output.asnumpy()

    if requested_output_content_type == "application/json":
        return json.dumps(data_to_return.tolist())

    raise NotImplementedError(
        "Model doesn't support the requested output type: " + requested_output_content_type
    )
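To put the handlers above in context, here is a hypothetical sketch of deploying this script with the SageMaker Python SDK and invoking it with the application/json content type it supports. The bucket, role ARN, and entry-point filename are illustrative assumptions, not values taken from this commit.

from sagemaker.mxnet import MXNetModel

# All names below are placeholders for illustration, not part of this commit.
model = MXNetModel(
    model_data="s3://my-bucket/mnist/model.tar.gz",  # must contain the "symbol" and "params" files model_fn loads
    role="arn:aws:iam::123456789012:role/SageMakerRole",
    entry_point="inference.py",  # assumed filename of this serving script
    framework_version="1.4.1",
    py_version="py3",
)

predictor = model.deploy(initial_instance_count=1, instance_type="ml.m5.xlarge")

# handle_json_input turns a JSON array into an NDArray, so send a nested list
# shaped like the (1, 1, 28, 28) input that model_fn binds.
payload = [[[[0.0] * 28] * 28]]
result = predictor.predict(payload)  # MXNetPredictor defaults to JSON in and out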
