Skip to content

Commit d3b37bb

Browse files
dlragha and chuyang-deng
authored and committed
feature: support Multi-Model endpoints
* feature: support Multi-Model endpoints
* fix: formatting for the Multi-Model endpoint files
* fix: enable local file uploads with add_model, use model parameters during multi-model deploy
* fix: patch multidatamodel unit test to work without default region env variable and add docker start cmd to buildspec
* fix: add validation to MultiDataModel constructor, fix integ test to enable runs in parallel
1 parent 57af22c commit d3b37bb

File tree

10 files changed

+1313
-4
lines changed

10 files changed

+1313
-4
lines changed

buildspec.yml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11
version: 0.2
22

33
phases:
4+
pre_build:
5+
commands:
6+
- start-dockerd
7+
48
build:
59
commands:
610
- IGNORE_COVERAGE=-

src/sagemaker/multidatamodel.py

Lines changed: 315 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,315 @@
1+
# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""This module contains code to create and manage SageMaker ``MultiDataModel``"""
14+
from __future__ import absolute_import
15+
16+
import os
17+
from six.moves.urllib.parse import urlparse
18+
19+
import sagemaker
20+
from sagemaker import s3
21+
from sagemaker.model import Model
22+
from sagemaker.session import Session
23+
24+
MULTI_MODEL_CONTAINER_MODE = "MultiModel"


class MultiDataModel(Model):
    """A SageMaker ``MultiDataModel`` that can be used to deploy multiple models to the same
    SageMaker ``Endpoint``, and also deploy additional models to an existing SageMaker
    multi-model ``Endpoint``
    """

    def __init__(
        self,
        name,
        model_data_prefix,
        model=None,
        image=None,
        role=None,
        sagemaker_session=None,
        **kwargs
    ):
        """Initialize a ``MultiDataModel``. In addition to these arguments, it supports all
        arguments supported by ``Model`` constructor

        Args:
            name (str): The model name.
            model_data_prefix (str): The S3 prefix where all the models artifacts (.tar.gz)
                in a Multi-Model endpoint are located
            model (sagemaker.Model): The Model object that would define the
                SageMaker model attributes like vpc_config, predictors, etc.
                If this is present, the attributes from this model are used when
                deploying the ``MultiDataModel``. Parameters 'image', 'role' and 'kwargs'
                are not permitted when model parameter is set.
            image (str): A Docker image URI. It can be null if the 'model' parameter
                is passed to during ``MultiDataModel`` initialization (default: None)
            role (str): An AWS IAM role (either name or full ARN). The Amazon
                SageMaker training jobs and APIs that create Amazon SageMaker
                endpoints use this role to access training data and model
                artifacts. After the endpoint is created, the inference code
                might use the IAM role if it needs to access some AWS resources.
                It can be null if this is being used to create a Model to pass
                to a ``PipelineModel`` which has its own Role field or if the 'model' parameter
                is passed to during ``MultiDataModel`` initialization (default: None)
            sagemaker_session (sagemaker.session.Session): A SageMaker Session
                object, used for SageMaker interactions (default: None). If not
                specified, one is created using the default AWS configuration
                chain.
            **kwargs: Keyword arguments passed to the ``Model`` initializer.

        Raises:
            ValueError: If ``model_data_prefix`` is not an ``s3://`` URI, or if
                ``model`` is given together with ``image``, ``role`` or extra kwargs.
        """
        # Validate path: all model artifacts must live under an S3 prefix.
        if not model_data_prefix.startswith("s3://"):
            raise ValueError(
                'Expecting S3 model prefix beginning with "s3://". Received: "{}"'.format(
                    model_data_prefix
                )
            )

        # 'model' is mutually exclusive with the attributes it would supply itself.
        if model and (image or role or kwargs):
            raise ValueError(
                "Parameters image, role or kwargs are not permitted when model parameter is passed."
            )

        self.name = name
        self.model_data_prefix = model_data_prefix
        self.model = model
        self.container_mode = MULTI_MODEL_CONTAINER_MODE
        self.sagemaker_session = sagemaker_session or Session()
        self.s3_client = self.sagemaker_session.boto_session.client("s3")

        # Set the ``Model`` parameters if the model parameter is not specified;
        # when 'model' is given, its attributes are used at deploy time instead.
        if not self.model:
            super(MultiDataModel, self).__init__(
                self.model_data_prefix,
                image,
                role,
                name=self.name,
                sagemaker_session=self.sagemaker_session,
                **kwargs
            )

    @staticmethod
    def _s3_path_join(*parts):
        """Join S3 key/URI components with "/" regardless of the host OS.

        ``os.path.join`` would emit backslashes on Windows, which are not valid
        S3 key separators. Empty components are skipped.
        """
        return "/".join(part.strip("/") for part in parts if part)

    def prepare_container_def(self, instance_type, accelerator_type=None):
        """Return a container definition set with MultiModel mode,
        model data and other parameters from the model (if available).

        Subclasses can override this to provide custom container definitions
        for deployment to a specific instance type. Called by ``deploy()``.

        Returns:
            dict[str, str]: A complete container definition object usable with the CreateModel API
        """
        # Copy the trained model's image and environment variables if they exist. Models trained
        # with FrameworkEstimator set framework specific environment variables which need to be
        # copied over
        if self.model:
            container_definition = self.model.prepare_container_def(instance_type, accelerator_type)
            image = container_definition["Image"]
            environment = container_definition["Environment"]
        else:
            image = self.image
            environment = self.env
        return sagemaker.container_def(
            image,
            env=environment,
            model_data_url=self.model_data_prefix,
            container_mode=self.container_mode,
        )

    def deploy(
        self,
        initial_instance_count,
        instance_type,
        accelerator_type=None,
        endpoint_name=None,
        update_endpoint=False,
        tags=None,
        kms_key=None,
        wait=True,
        data_capture_config=None,
    ):
        """Deploy this ``Model`` to an ``Endpoint`` and optionally return a ``Predictor``.

        Create a SageMaker ``Model`` and ``EndpointConfig``, and deploy an
        ``Endpoint`` from this ``Model``. If self.model is not None, then the ``Endpoint``
        will be deployed with parameters in self.model (like vpc_config,
        enable_network_isolation, etc). If self.model is None, then use the parameters
        in ``MultiDataModel`` constructor will be used. If ``self.predictor_cls`` is not
        None, this method returns a the result of invoking ``self.predictor_cls`` on
        the created endpoint name.

        The name of the created model is accessible in the ``name`` field of
        this ``Model`` after deploy returns

        The name of the created endpoint is accessible in the
        ``endpoint_name`` field of this ``Model`` after deploy returns.

        Args:
            initial_instance_count (int): The initial number of instances to run
                in the ``Endpoint`` created from this ``Model``.
            instance_type (str): The EC2 instance type to deploy this Model to.
                For example, 'ml.p2.xlarge', or 'local' for local mode.
            accelerator_type (str): Type of Elastic Inference accelerator to
                deploy this model for model loading and inference, for example,
                'ml.eia1.medium'. If not specified, no Elastic Inference
                accelerator will be attached to the endpoint. For more
                information:
                https://docs.aws.amazon.com/sagemaker/latest/dg/ei.html
            endpoint_name (str): The name of the endpoint to create (default:
                None). If not specified, a unique endpoint name will be created.
            update_endpoint (bool): Flag to update the model in an existing
                Amazon SageMaker endpoint. If True, this will deploy a new
                EndpointConfig to an already existing endpoint and delete
                resources corresponding to the previous EndpointConfig. If
                False, a new endpoint will be created. Default: False
            tags (List[dict[str, str]]): The list of tags to attach to this
                specific endpoint.
            kms_key (str): The ARN of the KMS key that is used to encrypt the
                data on the storage volume attached to the instance hosting the
                endpoint.
            wait (bool): Whether the call should wait until the deployment of
                this model completes (default: True).
            data_capture_config (sagemaker.model_monitor.DataCaptureConfig): Specifies
                configuration related to Endpoint data capture for use with
                Amazon SageMaker Model Monitoring. Default: None.

        Returns:
            callable[string, sagemaker.session.Session] or None: Invocation of
                ``self.predictor_cls`` on the created endpoint name,
                if ``self.predictor_cls``
                is not None. Otherwise, return None.

        Raises:
            ValueError: If no IAM role is available from either ``self.model``
                or this object.
        """
        # Set model specific parameters: prefer the wrapped model's attributes
        # when one was supplied to the constructor.
        if self.model:
            enable_network_isolation = self.model.enable_network_isolation()
            role = self.model.role
            vpc_config = self.model.vpc_config
            predictor = self.model.predictor_cls
        else:
            enable_network_isolation = self.enable_network_isolation()
            role = self.role
            vpc_config = self.vpc_config
            predictor = self.predictor_cls

        if role is None:
            raise ValueError("Role can not be null for deploying a model")

        container_def = self.prepare_container_def(instance_type, accelerator_type=accelerator_type)
        self.sagemaker_session.create_model(
            self.name,
            role,
            container_def,
            vpc_config=vpc_config,
            enable_network_isolation=enable_network_isolation,
            tags=tags,
        )

        production_variant = sagemaker.production_variant(
            self.name, instance_type, initial_instance_count, accelerator_type=accelerator_type
        )
        if endpoint_name:
            self.endpoint_name = endpoint_name
        else:
            self.endpoint_name = self.name

        data_capture_config_dict = None
        if data_capture_config is not None:
            data_capture_config_dict = data_capture_config._to_request_dict()

        if update_endpoint:
            # Point the existing endpoint at a freshly created EndpointConfig.
            endpoint_config_name = self.sagemaker_session.create_endpoint_config(
                name=self.name,
                model_name=self.name,
                initial_instance_count=initial_instance_count,
                instance_type=instance_type,
                accelerator_type=accelerator_type,
                tags=tags,
                kms_key=kms_key,
                data_capture_config_dict=data_capture_config_dict,
            )
            self.sagemaker_session.update_endpoint(self.endpoint_name, endpoint_config_name)
        else:
            self.sagemaker_session.endpoint_from_production_variants(
                name=self.endpoint_name,
                production_variants=[production_variant],
                tags=tags,
                kms_key=kms_key,
                wait=wait,
                data_capture_config_dict=data_capture_config_dict,
            )

        if predictor:
            return predictor(self.endpoint_name, self.sagemaker_session)
        return None

    def add_model(self, model_data_source, model_data_path=None):
        """Adds a model to the `MultiDataModel` by uploading or copying the model_data_source
        artifact to the given S3 path model_data_path relative to model_data_prefix

        Args:
            model_data_source (str): Valid local file path or S3 path of the trained
                model artifact
            model_data_path (str): S3 path where the trained model artifact
                should be uploaded relative to `self.model_data_prefix` path. (default: None).
                If None, then the model artifact is uploaded to a path relative to model_data_prefix

        Returns:
            str: S3 uri to uploaded model artifact

        Raises:
            ValueError: If ``model_data_source`` is neither an existing local
                file path nor an S3 URI.
        """
        parse_result = urlparse(model_data_source)

        # If the model source is an S3 path, copy the model artifact to the destination S3 path
        if parse_result.scheme == "s3":
            source_bucket, source_model_data_path = s3.parse_s3_url(model_data_source)
            copy_source = {"Bucket": source_bucket, "Key": source_model_data_path}

            if not model_data_path:
                model_data_path = source_model_data_path

            # Construct the destination path. S3 keys always use "/", so do not
            # use os.path.join here (it would emit "\" on Windows).
            dst_url = self._s3_path_join(self.model_data_prefix, model_data_path)
            destination_bucket, destination_model_data_path = s3.parse_s3_url(dst_url)

            # Copy the model artifact
            self.s3_client.copy(copy_source, destination_bucket, destination_model_data_path)
            return self._s3_path_join("s3://", destination_bucket, destination_model_data_path)

        # If the model source is a local path, upload the local model artifact to the destination
        # s3 path
        if os.path.exists(model_data_source):
            destination_bucket, dst_prefix = s3.parse_s3_url(self.model_data_prefix)
            if model_data_path:
                dst_s3_uri = self._s3_path_join(dst_prefix, model_data_path)
            else:
                dst_s3_uri = self._s3_path_join(dst_prefix, os.path.basename(model_data_source))
            self.s3_client.upload_file(model_data_source, destination_bucket, dst_s3_uri)
            return self._s3_path_join("s3://", destination_bucket, dst_s3_uri)

        # Raise error if the model source is of an unexpected type
        raise ValueError(
            "model_source must either be a valid local file path or s3 uri. Received: "
            '"{}"'.format(model_data_source)
        )

    def list_models(self):
        """Generates and returns relative paths to model archives stored at model_data_prefix
        S3 location.

        Yields: Paths to model archives relative to model_data_prefix path.
        """
        bucket, url_prefix = s3.parse_s3_url(self.model_data_prefix)
        file_keys = self.sagemaker_session.list_s3_files(bucket=bucket, key_prefix=url_prefix)
        for file_key in file_keys:
            # Return the model paths relative to the model_data_prefix.
            # Ex: "a/b/c.tar.gz" -> "b/c.tar.gz" where url_prefix = "a/"
            # Slice off only the *leading* prefix; str.replace() would also drop
            # any later occurrences of the prefix inside the key.
            if file_key.startswith(url_prefix):
                yield file_key[len(url_prefix):]
            else:
                yield file_key

src/sagemaker/predictor.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ def __init__(
8383
self._endpoint_config_name = self._get_endpoint_config_name()
8484
self._model_names = self._get_model_names()
8585

86-
def predict(self, data, initial_args=None):
86+
def predict(self, data, initial_args=None, target_model=None):
8787
"""Return the inference from the specified endpoint.
8888
8989
Args:
@@ -95,6 +95,9 @@ def predict(self, data, initial_args=None):
9595
initial_args (dict[str,str]): Optional. Default arguments for boto3
9696
``invoke_endpoint`` call. Default is None (no default
9797
arguments).
98+
target_model (str): S3 model artifact path to run an inference request on,
99+
in case of a multi model endpoint. Does not apply to endpoints hosting
100+
single model (Default: None)
98101
99102
Returns:
100103
object: Inference for the given input. If a deserializer was specified when creating
@@ -103,7 +106,7 @@ def predict(self, data, initial_args=None):
103106
as is.
104107
"""
105108

106-
request_args = self._create_request_args(data, initial_args)
109+
request_args = self._create_request_args(data, initial_args, target_model)
107110
response = self.sagemaker_session.sagemaker_runtime_client.invoke_endpoint(**request_args)
108111
return self._handle_response(response)
109112

@@ -120,11 +123,12 @@ def _handle_response(self, response):
120123
response_body.close()
121124
return data
122125

123-
def _create_request_args(self, data, initial_args=None):
126+
def _create_request_args(self, data, initial_args=None, target_model=None):
124127
"""
125128
Args:
126129
data:
127130
initial_args:
131+
target_model:
128132
"""
129133
args = dict(initial_args) if initial_args else {}
130134

@@ -137,6 +141,9 @@ def _create_request_args(self, data, initial_args=None):
137141
if self.accept and "Accept" not in args:
138142
args["Accept"] = self.accept
139143

144+
if target_model:
145+
args["TargetModel"] = target_model
146+
140147
if self.serializer is not None:
141148
data = self.serializer(data)
142149

0 commit comments

Comments
 (0)