Commit b6c9b0c

add tfs container support (#460)
* add tensorflow serving container support
1 parent 163bffd commit b6c9b0c

File tree

7 files changed: +517 −36 lines changed
CHANGELOG.rst

Lines changed: 2 additions & 2 deletions

@@ -2,12 +2,12 @@
 CHANGELOG
 =========
 
-1.13.1.dev
+1.14.0-dev
 ==========
 
+* feature: add support for sagemaker-tensorflow-serving container
 * feature: Estimator: make input channels optional
 
-
 1.13.0
 ======

src/sagemaker/predictor.py

Lines changed: 24 additions & 13 deletions

@@ -57,35 +57,28 @@ def __init__(self, endpoint, sagemaker_session=None, serializer=None, deserializ
         self.content_type = content_type or getattr(serializer, 'content_type', None)
         self.accept = accept or getattr(deserializer, 'accept', None)
 
-    def predict(self, data):
+    def predict(self, data, initial_args=None):
         """Return the inference from the specified endpoint.
 
         Args:
             data (object): Input data for which you want the model to provide inference.
                 If a serializer was specified when creating the RealTimePredictor, the result of the
                 serializer is sent as input data. Otherwise the data must be sequence of bytes, and
                 the predict method then sends the bytes in the request body as is.
+            initial_args (dict[str,str]): Optional. Default arguments for boto3
+                ``invoke_endpoint`` call. Default is None (no default arguments).
 
         Returns:
             object: Inference for the given input. If a deserializer was specified when creating
                 the RealTimePredictor, the result of the deserializer is returned. Otherwise the response
                 returns the sequence of bytes as is.
         """
-        if self.serializer is not None:
-            data = self.serializer(data)
-
-        request_args = {
-            'EndpointName': self.endpoint,
-            'Body': data
-        }
-
-        if self.content_type:
-            request_args['ContentType'] = self.content_type
-        if self.accept:
-            request_args['Accept'] = self.accept
 
+        request_args = self._create_request_args(data, initial_args)
         response = self.sagemaker_session.sagemaker_runtime_client.invoke_endpoint(**request_args)
+        return self._handle_response(response)
 
+    def _handle_response(self, response):
         response_body = response['Body']
         if self.deserializer is not None:
             # It's the deserializer's responsibility to close the stream
@@ -94,6 +87,24 @@ def predict(self, data):
             response_body.close()
         return data
 
+    def _create_request_args(self, data, initial_args=None):
+        args = dict(initial_args) if initial_args else {}
+
+        if 'EndpointName' not in args:
+            args['EndpointName'] = self.endpoint
+
+        if self.content_type and 'ContentType' not in args:
+            args['ContentType'] = self.content_type
+
+        if self.accept and 'Accept' not in args:
+            args['Accept'] = self.accept
+
+        if self.serializer is not None:
+            data = self.serializer(data)
+
+        args['Body'] = data
+        return args
+
     def delete_endpoint(self):
         """Delete the Amazon SageMaker endpoint backing this predictor.
         """

src/sagemaker/tensorflow/estimator.py

Lines changed: 59 additions & 21 deletions

@@ -22,12 +22,13 @@
 import time
 
 from sagemaker.estimator import Framework
-from sagemaker.fw_utils import framework_name_from_image, framework_version_from_tag, empty_framework_version_warning
-from sagemaker.utils import get_config_value
-from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT
-
+from sagemaker.fw_utils import framework_name_from_image, framework_version_from_tag, \
+    empty_framework_version_warning
 from sagemaker.tensorflow.defaults import TF_VERSION
 from sagemaker.tensorflow.model import TensorFlowModel
+from sagemaker.tensorflow.serving import Model
+from sagemaker.utils import get_config_value
+from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT
 
 logging.basicConfig()
 LOGGER = logging.getLogger('sagemaker')
@@ -103,12 +104,14 @@ def validate_requirements(self):
             EnvironmentError: If at least one requirement is not installed.
         """
         if not self._cmd_exists('tensorboard'):
-            raise EnvironmentError('TensorBoard is not installed in the system. Please install TensorBoard using the'
-                                   ' following command: \n pip install tensorboard')
+            raise EnvironmentError(
+                'TensorBoard is not installed in the system. Please install TensorBoard using the'
+                ' following command: \n pip install tensorboard')
 
         if not self._cmd_exists('aws'):
-            raise EnvironmentError('The AWS CLI is not installed in the system. Please install the AWS CLI using the'
-                                   ' following command: \n pip install awscli')
+            raise EnvironmentError(
+                'The AWS CLI is not installed in the system. Please install the AWS CLI using the'
+                ' following command: \n pip install awscli')
 
     def create_tensorboard_process(self):
         """Create a TensorBoard process.
@@ -125,7 +128,8 @@ def create_tensorboard_process(self):
 
         for i in range(100):
             p = subprocess.Popen(
-                ["tensorboard", "--logdir", self.logdir, "--host", "localhost", "--port", str(port)],
+                ["tensorboard", "--logdir", self.logdir, "--host", "localhost", "--port",
+                 str(port)],
                 stdout=subprocess.PIPE,
                 stderr=subprocess.PIPE
             )
@@ -135,7 +139,8 @@ def create_tensorboard_process(self):
         else:
             return port, p
 
-        raise OSError('No available ports to start TensorBoard. Attempted all ports between 6006 and 6105')
+        raise OSError(
+            'No available ports to start TensorBoard. Attempted all ports between 6006 and 6105')
 
     def run(self):
         """Run TensorBoard process."""
@@ -158,7 +163,8 @@ class TensorFlow(Framework):
 
     __framework_name__ = 'tensorflow'
 
-    def __init__(self, training_steps=None, evaluation_steps=None, checkpoint_path=None, py_version='py2',
+    def __init__(self, training_steps=None, evaluation_steps=None, checkpoint_path=None,
+                 py_version='py2',
                  framework_version=None, requirements_file='', image_name=None, **kwargs):
         """Initialize an ``TensorFlow`` estimator.
         Args:
@@ -202,7 +208,8 @@ def _validate_requirements_file(self, requirements_file):
             raise ValueError('Must specify source_dir along with a requirements file.')
 
         if os.path.isabs(requirements_file):
-            raise ValueError('Requirements file {} is not a path relative to source_dir.'.format(requirements_file))
+            raise ValueError('Requirements file {} is not a path relative to source_dir.'.format(
+                requirements_file))
 
         if not os.path.exists(os.path.join(self.source_dir, requirements_file)):
             raise ValueError('Requirements file {} does not exist.'.format(requirements_file))
@@ -231,6 +238,7 @@ def fit(self, inputs=None, wait=True, logs=True, job_name=None, run_tensorboard_
                 downloaded checkpoint information (default: False). This is an experimental feature, and requires
                 TensorBoard and AWS CLI to be installed. It terminates TensorBoard when execution ends.
         """
+
         def fit_super():
             super(TensorFlow, self).fit(inputs, wait, logs, job_name)
 
@@ -263,7 +271,8 @@ def _prepare_init_params_from_job_description(cls, job_details, model_channel_na
             dictionary: The transformed init_params
 
         """
-        init_params = super(TensorFlow, cls)._prepare_init_params_from_job_description(job_details, model_channel_name)
+        init_params = super(TensorFlow, cls)._prepare_init_params_from_job_description(job_details,
+                                                                                       model_channel_name)
 
         # Move some of the tensorflow specific init params from hyperparameters into the main init params.
         for argument in ['checkpoint_path', 'training_steps', 'evaluation_steps']:
@@ -285,15 +294,18 @@ def _prepare_init_params_from_job_description(cls, job_details, model_channel_na
         # containing framework version, device type and python version (e.g. '1.5-gpu-py2').
         # For backward compatibility map deprecated image tag '1.0' to a '1.4' framework version
         # otherwise extract framework version from the tag itself.
-        init_params['framework_version'] = '1.4' if tag == '1.0' else framework_version_from_tag(tag)
+        init_params['framework_version'] = '1.4' if tag == '1.0' else framework_version_from_tag(
+            tag)
 
         training_job_name = init_params['base_job_name']
         if framework != cls.__framework_name__:
-            raise ValueError("Training job: {} didn't use image for requested framework".format(training_job_name))
+            raise ValueError("Training job: {} didn't use image for requested framework".format(
+                training_job_name))
 
         return init_params
 
-    def create_model(self, model_server_workers=None, role=None, vpc_config_override=VPC_CONFIG_DEFAULT):
+    def create_model(self, model_server_workers=None, role=None,
+                     vpc_config_override=VPC_CONFIG_DEFAULT, endpoint_type=None):
         """Create a SageMaker ``TensorFlowModel`` object that can be deployed to an ``Endpoint``.
 
         Args:
@@ -305,18 +317,44 @@ def create_model(self, model_server_workers=None, role=None, vpc_config_override
                 Default: use subnets and security groups from this Estimator.
                 * 'Subnets' (list[str]): List of subnet ids.
                 * 'SecurityGroupIds' (list[str]): List of security group ids.
+            endpoint_type: Optional. Selects the software stack used by the inference server.
+                If not specified, the model will be configured to use the default
+                SageMaker model server. If 'tensorflow-serving', the model will be configured to
+                use the SageMaker Tensorflow Serving container.
 
         Returns:
             sagemaker.tensorflow.model.TensorFlowModel: A SageMaker ``TensorFlowModel`` object.
                 See :func:`~sagemaker.tensorflow.model.TensorFlowModel` for full details.
         """
-        env = {'SAGEMAKER_REQUIREMENTS': self.requirements_file}
+
         role = role or self.role
-        return TensorFlowModel(self.model_data, role, self.entry_point, source_dir=self._model_source_dir(),
-                               enable_cloudwatch_metrics=self.enable_cloudwatch_metrics, env=env, image=self.image_name,
-                               name=self._current_job_name, container_log_level=self.container_log_level,
+        if endpoint_type == 'tensorflow-serving':
+            return self._create_tfs_model(role=role, vpc_config_override=vpc_config_override)
+
+        return self._create_default_model(model_server_workers=model_server_workers, role=role,
+                                          vpc_config_override=vpc_config_override)
+
+    def _create_tfs_model(self, role=None, vpc_config_override=VPC_CONFIG_DEFAULT):
+        return Model(model_data=self.model_data,
+                     role=role,
+                     image=self.image_name,
+                     name=self._current_job_name,
+                     container_log_level=self.container_log_level,
+                     framework_version=self.framework_version,
+                     sagemaker_session=self.sagemaker_session,
+                     vpc_config=self.get_vpc_config(vpc_config_override))
+
+    def _create_default_model(self, model_server_workers, role, vpc_config_override):
+        return TensorFlowModel(self.model_data, role, self.entry_point,
+                               source_dir=self._model_source_dir(),
+                               enable_cloudwatch_metrics=self.enable_cloudwatch_metrics,
+                               env={'SAGEMAKER_REQUIREMENTS': self.requirements_file},
+                               image=self.image_name,
+                               name=self._current_job_name,
+                               container_log_level=self.container_log_level,
                                code_location=self.code_location, py_version=self.py_version,
-                               framework_version=self.framework_version, model_server_workers=model_server_workers,
+                               framework_version=self.framework_version,
+                               model_server_workers=model_server_workers,
                                sagemaker_session=self.sagemaker_session,
                                vpc_config=self.get_vpc_config(vpc_config_override))
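
With this change, the serving stack becomes a single argument at model-creation time. A sketch under assumed settings (the entry point, role, instance types, S3 path, and framework version here are hypothetical):

    from sagemaker.tensorflow import TensorFlow

    # Hypothetical training setup, for illustration only.
    estimator = TensorFlow(entry_point='train.py',
                           role='SageMakerRole',
                           train_instance_count=1,
                           train_instance_type='ml.c4.xlarge',
                           framework_version='1.11')
    estimator.fit('s3://my-bucket/training-data')

    # Default: the existing SageMaker TensorFlow model server.
    model = estimator.create_model()

    # New: the SageMaker TensorFlow Serving container.
    tfs_model = estimator.create_model(endpoint_type='tensorflow-serving')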

src/sagemaker/tensorflow/serving.py

Lines changed: 149 additions & 0 deletions

@@ -0,0 +1,149 @@
+# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+#     http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+from __future__ import absolute_import
+
+import logging
+
+import sagemaker
+from sagemaker.content_types import CONTENT_TYPE_JSON
+from sagemaker.fw_utils import create_image_uri
+from sagemaker.predictor import json_serializer, json_deserializer
+from sagemaker.tensorflow.defaults import TF_VERSION
+
+
+class Predictor(sagemaker.RealTimePredictor):
+    """A ``RealTimePredictor`` implementation for inference against TensorFlow Serving endpoints.
+    """
+
+    def __init__(self, endpoint_name, sagemaker_session=None,
+                 serializer=json_serializer,
+                 deserializer=json_deserializer,
+                 model_name=None,
+                 model_version=None):
+        """Initialize a ``TFSPredictor``. See ``sagemaker.RealTimePredictor`` for
+        more info about parameters.
+
+        Args:
+            endpoint_name (str): The name of the endpoint to perform inference on.
+            sagemaker_session (sagemaker.session.Session): Session object which manages interactions
+                with Amazon SageMaker APIs and any other AWS services needed. If not specified,
+                the estimator creates one using the default AWS configuration chain.
+            serializer (callable): Optional. Default serializes input data to json. Handles dicts,
+                lists, and numpy arrays.
+            deserializer (callable): Optional. Default parses the response using ``json.load(...)``.
+            model_name (str): Optional. The name of the SavedModel model that should handle the
+                request. If not specified, the endpoint's default model will handle the request.
+            model_version (str): Optional. The version of the SavedModel model that should handle
+                the request. If not specified, the latest version of the model will be used.
+        """
+        super(Predictor, self).__init__(endpoint_name, sagemaker_session, serializer,
+                                        deserializer)
+
+        attributes = []
+        if model_name:
+            attributes.append('tfs-model-name={}'.format(model_name))
+        if model_version:
+            attributes.append('tfs-model-version={}'.format(model_version))
+        self._model_attributes = ','.join(attributes) if attributes else None
+
+    def classify(self, data):
+        return self._classify_or_regress(data, 'classify')
+
+    def regress(self, data):
+        return self._classify_or_regress(data, 'regress')
+
+    def _classify_or_regress(self, data, method):
+        if method not in ['classify', 'regress']:
+            raise ValueError('invalid TensorFlow Serving method: {}'.format(method))
+
+        if self.content_type != CONTENT_TYPE_JSON:
+            raise ValueError('The {} api requires json requests.'.format(method))
+
+        args = {
+            'CustomAttributes': 'tfs-method={}'.format(method)
+        }
+
+        return self.predict(data, args)
+
+    def predict(self, data, initial_args=None):
+        args = dict(initial_args) if initial_args else {}
+        if self._model_attributes:
+            if 'CustomAttributes' in args:
+                args['CustomAttributes'] += ',' + self._model_attributes
+            else:
+                args['CustomAttributes'] = self._model_attributes
+
+        return super(Predictor, self).predict(data, args)
+
+
+class Model(sagemaker.Model):
+    FRAMEWORK_NAME = 'tensorflow-serving'
+    LOG_LEVEL_PARAM_NAME = 'SAGEMAKER_TFS_NGINX_LOGLEVEL'
+    LOG_LEVEL_MAP = {
+        logging.DEBUG: 'debug',
+        logging.INFO: 'info',
+        logging.WARNING: 'warn',
+        logging.ERROR: 'error',
+        logging.CRITICAL: 'crit',
+    }
+
+    def __init__(self, model_data, role, image=None, framework_version=TF_VERSION,
+                 container_log_level=None, predictor_cls=Predictor, **kwargs):
+        """Initialize a Model.
+
+        Args:
+            model_data (str): The S3 location of a SageMaker model data ``.tar.gz`` file.
+            role (str): An AWS IAM role (either name or full ARN). The Amazon SageMaker APIs that
+                create Amazon SageMaker endpoints use this role to access model artifacts.
+            image (str): A Docker image URI (default: None). If not specified, a default image for
+                TensorFlow Serving will be used.
+            framework_version (str): Optional. TensorFlow Serving version you want to use.
+            container_log_level (int): Log level to use within the container (default: logging.ERROR).
+                Valid values are defined in the Python logging module.
+            predictor_cls (callable[str, sagemaker.session.Session]): A function to call to create a
+                predictor with an endpoint name and SageMaker ``Session``. If specified, ``deploy()``
+                returns the result of invoking this function on the created endpoint name.
+            **kwargs: Keyword arguments passed to the ``Model`` initializer.
+        """
+        super(Model, self).__init__(model_data=model_data, role=role, image=image,
+                                    predictor_cls=predictor_cls, **kwargs)
+        self._framework_version = framework_version
+        self._container_log_level = container_log_level
+
+    def prepare_container_def(self, instance_type):
+        image = self._get_image_uri(instance_type)
+        env = self._get_container_env()
+        return sagemaker.container_def(image, self.model_data, env)
+
+    def _get_container_env(self):
+        if not self._container_log_level:
+            return self.env
+
+        if self._container_log_level not in Model.LOG_LEVEL_MAP:
+            logging.warning('ignoring invalid container log level: %s', self._container_log_level)
+            return self.env
+
+        env = dict(self.env)
+        env['SAGEMAKER_TFS_NGINX_LOGLEVEL'] = Model.LOG_LEVEL_MAP[self._container_log_level]
+        return env
+
+    def _get_image_uri(self, instance_type):
+        if self.image:
+            return self.image
+
+        # reuse standard image uri function, then strip unwanted python component
+        region_name = self.sagemaker_session.boto_region_name
+        image = create_image_uri(region_name, Model.FRAMEWORK_NAME, instance_type,
+                                 self._framework_version, 'py3')
+        image = image.replace('-py3', '')
+        return image
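
Finally, a sketch of how the new Predictor might be exercised against a deployed TFS endpoint; the endpoint name, model name, and input payloads are hypothetical:

    from sagemaker.tensorflow.serving import Predictor

    # Hypothetical endpoint and SavedModel names, for illustration only.
    predictor = Predictor('my-tfs-endpoint', model_name='my_model', model_version='2')

    # model_name/model_version travel to the container as CustomAttributes
    # (tfs-model-name, tfs-model-version) on every request.
    result = predictor.predict({'instances': [1.0, 2.0, 5.0]})

    # classify() and regress() add tfs-method to CustomAttributes; both require
    # the JSON content type that the default serializer already provides.
    scores = predictor.classify({'examples': [{'x': 1.0}]})
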
Binary file (3.13 KB) not shown.
