
Commit 8c62b99

fm1ch4 authored and abnagara committed
Add sample notebooks for multi-model endpoints functionality (#935)
* Add sample notebooks for multi-model endpoints functionality.
1 parent 240ca21 commit 8c62b99

7 files changed: +2534 -0 lines changed

README.md

Lines changed: 3 additions & 0 deletions
@@ -109,6 +109,9 @@ These examples showcase unique functionality available in Amazon SageMaker.
 - [Inference Pipeline with SparkML and XGBoost](advanced_functionality/inference_pipeline_sparkml_xgboost_abalone) shows how to deploy an Inference Pipeline with SparkML for data pre-processing and XGBoost for training on the Abalone dataset. The pre-processing code is written once and used between training and inference.
 - [Inference Pipeline with SparkML and BlazingText](advanced_functionality/inference_pipeline_sparkml_blazingtext_dbpedia) shows how to deploy an Inference Pipeline with SparkML for data pre-processing and BlazingText for training on the DBPedia dataset. The pre-processing code is written once and used between training and inference.
 - [Experiment Management Capabilities with Search](advanced_functionality/search) shows how to organize Training Jobs into projects, and track relationships between Models, Endpoints, and Training Jobs.
+- [Host Multiple Models with Your Own Algorithm](advanced_functionality/multi_model_bring_your_own) shows how to deploy multiple models to a realtime hosted endpoint with your own custom algorithm.
+- [Host Multiple Models with XGBoost](advanced_functionality/multi_model_xgboost_home_value) shows how to deploy multiple models to a realtime hosted endpoint using a multi-model enabled XGBoost container.
+- [Host Multiple Models with SKLearn](advanced_functionality/multi_model_sklearn_home_value) shows how to deploy multiple models to a realtime hosted endpoint using a multi-model enabled SKLearn container.
 
 ### Amazon SageMaker Neo Compilation Jobs
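
The three new entries above all deploy multiple models behind a single real-time endpoint. For orientation, here is a minimal sketch of invoking such an endpoint once deployed: the TargetModel parameter of invoke_endpoint selects which model archive under the endpoint's S3 prefix serves the request, and models are loaded on first use. The endpoint name, artifact name, and input file below are hypothetical.

import boto3

runtime = boto3.client("sagemaker-runtime")

# Hypothetical input: any payload the container's handler understands
with open("kitten.jpg", "rb") as f:
    payload = f.read()

response = runtime.invoke_endpoint(
    EndpointName="my-multi-model-endpoint",  # hypothetical endpoint name
    TargetModel="resnet-18.tar.gz",          # hypothetical archive under the S3 model prefix
    ContentType="application/x-image",
    Body=payload,
)
print(response["Body"].read())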

Dockerfile

Lines changed: 43 additions & 0 deletions
@@ -0,0 +1,43 @@
FROM ubuntu:16.04

# Set a docker label to advertise multi-model support on the container
LABEL com.amazonaws.sagemaker.capabilities.multi-models=true
# Set a docker label to enable container to use SAGEMAKER_BIND_TO_PORT environment variable if present
LABEL com.amazonaws.sagemaker.capabilities.accept-bind-to-port=true

# Install necessary dependencies for MMS and SageMaker Inference Toolkit
RUN apt-get update && \
    apt-get -y install --no-install-recommends \
    build-essential \
    ca-certificates \
    openjdk-8-jdk-headless \
    python3-dev \
    curl \
    vim \
    && rm -rf /var/lib/apt/lists/* \
    && curl -O https://bootstrap.pypa.io/get-pip.py \
    && python3 get-pip.py

RUN update-alternatives --install /usr/bin/python python /usr/bin/python3 1
RUN update-alternatives --install /usr/local/bin/pip pip /usr/local/bin/pip3 1

# Install MXNet, MMS, and SageMaker Inference Toolkit to set up MMS
RUN pip3 --no-cache-dir install mxnet \
                                multi-model-server \
                                sagemaker-inference \
                                retrying

# Copy entrypoint script to the image
COPY dockerd-entrypoint.py /usr/local/bin/dockerd-entrypoint.py
RUN chmod +x /usr/local/bin/dockerd-entrypoint.py

RUN mkdir -p /home/model-server/

# Copy the default custom service file to handle incoming data and inference requests
COPY model_handler.py /home/model-server/model_handler.py

# Define an entrypoint script for the docker image
ENTRYPOINT ["python", "/usr/local/bin/dockerd-entrypoint.py"]

# Define command to be passed to the entrypoint
CMD ["serve"]
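
The com.amazonaws.sagemaker.capabilities.multi-models=true label above is what advertises the container as multi-model capable. As a hedged sketch of where that fits, an image built from this Dockerfile would be registered for multi-model hosting roughly as below; setting Mode to "MultiModel" in the container definition makes hosting treat ModelDataUrl as an S3 prefix holding many model archives rather than a single model. The model name, role ARN, image URI, and S3 prefix are all hypothetical.

import boto3

sm = boto3.client("sagemaker")

sm.create_model(
    ModelName="demo-multi-model",  # hypothetical
    ExecutionRoleArn="arn:aws:iam::123456789012:role/SageMakerRole",  # hypothetical
    PrimaryContainer={
        "Image": "123456789012.dkr.ecr.us-east-1.amazonaws.com/multi-model:latest",  # hypothetical
        "Mode": "MultiModel",
        "ModelDataUrl": "s3://my-bucket/multi-model/",  # prefix holding the *.tar.gz model archives
    },
)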
dockerd-entrypoint.py

Lines changed: 29 additions & 0 deletions
@@ -0,0 +1,29 @@
import subprocess
import sys
import shlex
import os
from retrying import retry
from subprocess import CalledProcessError
from sagemaker_inference import model_server


def _retry_if_error(exception):
    return isinstance(exception, (CalledProcessError, OSError))


@retry(stop_max_delay=1000 * 50,
       retry_on_exception=_retry_if_error)
def _start_mms():
    # by default the number of workers per model is 1, but we can configure it through the
    # environment variable below if desired.
    # os.environ['SAGEMAKER_MODEL_SERVER_WORKERS'] = '2'
    model_server.start_model_server(handler_service='/home/model-server/model_handler.py:handle')


def main():
    if sys.argv[1] == 'serve':
        _start_mms()
    else:
        subprocess.check_call(shlex.split(' '.join(sys.argv[1:])))

    # prevent docker exit
    subprocess.call(['tail', '-f', '/dev/null'])


main()
model_handler.py

Lines changed: 167 additions & 0 deletions
@@ -0,0 +1,167 @@
"""
ModelHandler defines an example model handler for load and inference requests for MXNet CPU models
"""
from collections import namedtuple
import glob
import json
import logging
import os
import re

import mxnet as mx
import numpy as np


class ModelHandler(object):
    """
    A sample Model handler implementation.
    """

    def __init__(self):
        self.initialized = False
        self.mx_model = None
        self.shapes = None

    def get_model_files_prefix(self, model_dir):
        """
        Get the model prefix name for the model artifacts (symbol and parameter file).
        This assumes the model artifact directory contains a symbol file, parameter file,
        model shapes file and a synset file defining the labels

        :param model_dir: Path to the directory with model artifacts
        :return: prefix string for model artifact files
        """
        sym_file_suffix = "-symbol.json"
        checkpoint_prefix_regex = "{}/*{}".format(model_dir, sym_file_suffix)  # Ex output: /opt/ml/models/resnet-18/model/*-symbol.json
        checkpoint_prefix_filename = glob.glob(checkpoint_prefix_regex)[0]  # Ex output: /opt/ml/models/resnet-18/model/resnet18-symbol.json
        checkpoint_prefix = os.path.basename(checkpoint_prefix_filename).split(sym_file_suffix)[0]  # Ex output: resnet18
        logging.info("Prefix for the model artifacts: {}".format(checkpoint_prefix))
        return checkpoint_prefix

    def get_input_data_shapes(self, model_dir, checkpoint_prefix):
        """
        Get the model input data shapes and return the list

        :param model_dir: Path to the directory with model artifacts
        :param checkpoint_prefix: Model files prefix name
        :return: list of input data shapes
        """
        shapes_file_path = os.path.join(model_dir, "{}-{}".format(checkpoint_prefix, "shapes.json"))
        if not os.path.isfile(shapes_file_path):
            raise RuntimeError("Missing {} file.".format(shapes_file_path))

        with open(shapes_file_path) as f:
            self.shapes = json.load(f)

        data_shapes = []

        for input_data in self.shapes:
            data_name = input_data["name"]
            data_shape = input_data["shape"]
            data_shapes.append((data_name, tuple(data_shape)))

        return data_shapes

    def initialize(self, context):
        """
        Initialize model. This will be called during model loading time
        :param context: Initial context contains model server system properties.
        :return:
        """
        self.initialized = True
        properties = context.system_properties
        # Contains the url parameter passed to the load request
        model_dir = properties.get("model_dir")
        gpu_id = properties.get("gpu_id")

        checkpoint_prefix = self.get_model_files_prefix(model_dir)

        # Read the model input data shapes
        data_shapes = self.get_input_data_shapes(model_dir, checkpoint_prefix)

        # Load MXNet model
        try:
            ctx = mx.cpu()  # Set the context on CPU
            sym, arg_params, aux_params = mx.model.load_checkpoint(checkpoint_prefix, 0)  # epoch set to 0
            self.mx_model = mx.mod.Module(symbol=sym, context=ctx, label_names=None)
            self.mx_model.bind(for_training=False, data_shapes=data_shapes,
                               label_shapes=self.mx_model._label_shapes)
            self.mx_model.set_params(arg_params, aux_params, allow_missing=True)
            with open("synset.txt", 'r') as f:
                self.labels = [l.rstrip() for l in f]
        except (mx.base.MXNetError, RuntimeError) as memerr:
            if re.search('Failed to allocate (.*) Memory', str(memerr), re.IGNORECASE):
                logging.error("Memory allocation exception: {}".format(memerr))
                raise MemoryError
            raise

    def preprocess(self, request):
        """
        Transform raw input into model input data.
        :param request: list of raw requests
        :return: list of preprocessed model input data
        """
        # Take the input data and pre-process it to make it inference ready

        img_list = []
        for idx, data in enumerate(request):
            # Read the bytearray of the image from the input
            img_arr = data.get('body')

            # Input image is in bytearray, convert it to MXNet NDArray
            img = mx.img.imdecode(img_arr)
            if img is None:
                return None

            # convert into format (batch, RGB, width, height)
            img = mx.image.imresize(img, 224, 224)  # resize
            img = img.transpose((2, 0, 1))  # Channel first
            img = img.expand_dims(axis=0)  # batchify
            img_list.append(img)

        return img_list

    def inference(self, model_input):
        """
        Internal inference methods
        :param model_input: transformed model input data list
        :return: list of inference output in NDArray
        """
        # Do some inference call to engine here and return output
        Batch = namedtuple('Batch', ['data'])
        self.mx_model.forward(Batch(model_input))
        prob = self.mx_model.get_outputs()[0].asnumpy()
        return prob

    def postprocess(self, inference_output):
        """
        Return predict result as a list.
        :param inference_output: list of inference output
        :return: list of predict results
        """
        # Take output from network and post-process to desired format
        prob = np.squeeze(inference_output)
        a = np.argsort(prob)[::-1]
        return [['probability=%f, class=%s' % (prob[i], self.labels[i]) for i in a[0:5]]]

    def handle(self, data, context):
        """
        Call preprocess, inference and post-process functions
        :param data: input data
        :param context: mms context
        """

        model_input = self.preprocess(data)
        model_out = self.inference(model_input)
        return self.postprocess(model_out)


_service = ModelHandler()


def handle(data, context):
    if not _service.initialized:
        _service.initialize(context)

    if data is None:
        return None

    return _service.handle(data, context)
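
Tying this together: the handler above expects each model archive to contain a <prefix>-symbol.json symbol file, an epoch-0 <prefix>-0000.params parameter file, a <prefix>-shapes.json input-shape file, and a synset.txt of labels. A rough sketch of packaging and uploading one such model follows; the artifact, bucket, and key names are hypothetical.

import tarfile
import boto3

# Hypothetical artifact files sharing the checkpoint prefix "resnet18"
artifacts = ["resnet18-symbol.json", "resnet18-0000.params",
             "resnet18-shapes.json", "synset.txt"]

with tarfile.open("resnet-18.tar.gz", "w:gz") as tar:
    for name in artifacts:
        tar.add(name)

# Upload into the prefix the multi-model endpoint serves from (hypothetical bucket/key)
boto3.client("s3").upload_file("resnet-18.tar.gz", "my-bucket", "multi-model/resnet-18.tar.gz")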
