
add tfs container support #460

Merged: 8 commits, Nov 7, 2018
4 changes: 2 additions & 2 deletions CHANGELOG.rst
@@ -2,12 +2,12 @@
CHANGELOG
=========

1.13.1.dev
1.14.0-dev
==========

* feature: add support for sagemaker-tensorflow-serving container
* feature: Estimator: make input channels optional


1.13.0
======

37 changes: 24 additions & 13 deletions src/sagemaker/predictor.py
@@ -57,35 +57,28 @@ def __init__(self, endpoint, sagemaker_session=None, serializer=None, deserializ
self.content_type = content_type or getattr(serializer, 'content_type', None)
self.accept = accept or getattr(deserializer, 'accept', None)

def predict(self, data):
def predict(self, data, initial_args=None):
"""Return the inference from the specified endpoint.

Args:
data (object): Input data for which you want the model to provide inference.
If a serializer was specified when creating the RealTimePredictor, the result of the
serializer is sent as input data. Otherwise the data must be a sequence of bytes, and
the predict method then sends the bytes in the request body as is.
initial_args (dict[str,str]): Optional. Default arguments for boto3
``invoke_endpoint`` call. Default is None (no default arguments).

Returns:
object: Inference for the given input. If a deserializer was specified when creating
the RealTimePredictor, the result of the deserializer is returned. Otherwise the response
returns the sequence of bytes as is.
"""
if self.serializer is not None:
data = self.serializer(data)

request_args = {
'EndpointName': self.endpoint,
'Body': data
}

if self.content_type:
request_args['ContentType'] = self.content_type
if self.accept:
request_args['Accept'] = self.accept

request_args = self._create_request_args(data, initial_args)
response = self.sagemaker_session.sagemaker_runtime_client.invoke_endpoint(**request_args)
return self._handle_response(response)

def _handle_response(self, response):
response_body = response['Body']
if self.deserializer is not None:
# It's the deserializer's responsibility to close the stream
@@ -94,6 +87,24 @@ def predict(self, data):
response_body.close()
return data

def _create_request_args(self, data, initial_args=None):
args = dict(initial_args) if initial_args else {}

if 'EndpointName' not in args:
args['EndpointName'] = self.endpoint

if self.content_type and 'ContentType' not in args:
args['ContentType'] = self.content_type

if self.accept and 'Accept' not in args:
args['Accept'] = self.accept

if self.serializer is not None:
data = self.serializer(data)

args['Body'] = data
return args

def delete_endpoint(self):
"""Delete the Amazon SageMaker endpoint backing this predictor.
"""
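For review context, here is a minimal sketch of how the new `initial_args` parameter could be used from client code. The endpoint name, payload, and custom attribute are hypothetical; `RealTimePredictor`, `json_serializer`, and `json_deserializer` are the existing helpers in `sagemaker.predictor`.

```python
from sagemaker.predictor import RealTimePredictor, json_serializer, json_deserializer

# Hypothetical endpoint name; any deployed SageMaker endpoint would work here.
predictor = RealTimePredictor('my-endpoint',
                              serializer=json_serializer,
                              deserializer=json_deserializer)

# initial_args seeds the boto3 invoke_endpoint call. Per _create_request_args,
# EndpointName / ContentType / Accept set here win over the predictor's defaults;
# the CustomAttributes value below is just an illustrative string.
result = predictor.predict({'instances': [1.0, 2.0, 3.0]},
                           initial_args={'CustomAttributes': 'trace-id=abc123'})
```

This is the hook the new `serving.Predictor` relies on: it injects `tfs-method` and model attributes through `initial_args` without clobbering anything the caller already set.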
80 changes: 59 additions & 21 deletions src/sagemaker/tensorflow/estimator.py
@@ -22,12 +22,13 @@
import time

from sagemaker.estimator import Framework
from sagemaker.fw_utils import framework_name_from_image, framework_version_from_tag, empty_framework_version_warning
from sagemaker.utils import get_config_value
from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT

from sagemaker.fw_utils import framework_name_from_image, framework_version_from_tag, \
empty_framework_version_warning
from sagemaker.tensorflow.defaults import TF_VERSION
from sagemaker.tensorflow.model import TensorFlowModel
from sagemaker.tensorflow.serving import Model
from sagemaker.utils import get_config_value
from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT

logging.basicConfig()
LOGGER = logging.getLogger('sagemaker')
@@ -103,12 +104,14 @@ def validate_requirements(self):
EnvironmentError: If at least one requirement is not installed.
"""
if not self._cmd_exists('tensorboard'):
raise EnvironmentError('TensorBoard is not installed in the system. Please install TensorBoard using the'
' following command: \n pip install tensorboard')
raise EnvironmentError(
'TensorBoard is not installed in the system. Please install TensorBoard using the'
' following command: \n pip install tensorboard')

if not self._cmd_exists('aws'):
raise EnvironmentError('The AWS CLI is not installed in the system. Please install the AWS CLI using the'
' following command: \n pip install awscli')
raise EnvironmentError(
'The AWS CLI is not installed in the system. Please install the AWS CLI using the'
' following command: \n pip install awscli')

def create_tensorboard_process(self):
"""Create a TensorBoard process.
@@ -125,7 +128,8 @@ def create_tensorboard_process(self):

for i in range(100):
p = subprocess.Popen(
["tensorboard", "--logdir", self.logdir, "--host", "localhost", "--port", str(port)],
["tensorboard", "--logdir", self.logdir, "--host", "localhost", "--port",
str(port)],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE
)
@@ -135,7 +139,8 @@ def create_tensorboard_process(self):
else:
return port, p

raise OSError('No available ports to start TensorBoard. Attempted all ports between 6006 and 6105')
raise OSError(
'No available ports to start TensorBoard. Attempted all ports between 6006 and 6105')

def run(self):
"""Run TensorBoard process."""
@@ -158,7 +163,8 @@ class TensorFlow(Framework):

__framework_name__ = 'tensorflow'

def __init__(self, training_steps=None, evaluation_steps=None, checkpoint_path=None, py_version='py2',
def __init__(self, training_steps=None, evaluation_steps=None, checkpoint_path=None,
py_version='py2',
framework_version=None, requirements_file='', image_name=None, **kwargs):
"""Initialize an ``TensorFlow`` estimator.
Args:
@@ -202,7 +208,8 @@ def _validate_requirements_file(self, requirements_file):
raise ValueError('Must specify source_dir along with a requirements file.')

if os.path.isabs(requirements_file):
raise ValueError('Requirements file {} is not a path relative to source_dir.'.format(requirements_file))
raise ValueError('Requirements file {} is not a path relative to source_dir.'.format(
requirements_file))

if not os.path.exists(os.path.join(self.source_dir, requirements_file)):
raise ValueError('Requirements file {} does not exist.'.format(requirements_file))
@@ -231,6 +238,7 @@ def fit(self, inputs=None, wait=True, logs=True, job_name=None, run_tensorboard_
downloaded checkpoint information (default: False). This is an experimental feature, and requires
TensorBoard and AWS CLI to be installed. It terminates TensorBoard when execution ends.
"""

def fit_super():
super(TensorFlow, self).fit(inputs, wait, logs, job_name)

@@ -263,7 +271,8 @@ def _prepare_init_params_from_job_description(cls, job_details, model_channel_na
dictionary: The transformed init_params

"""
init_params = super(TensorFlow, cls)._prepare_init_params_from_job_description(job_details, model_channel_name)
init_params = super(TensorFlow, cls)._prepare_init_params_from_job_description(job_details,
model_channel_name)

# Move some of the tensorflow specific init params from hyperparameters into the main init params.
for argument in ['checkpoint_path', 'training_steps', 'evaluation_steps']:
@@ -285,15 +294,18 @@ def _prepare_init_params_from_job_description(cls, job_details, model_channel_na
# containing framework version, device type and python version (e.g. '1.5-gpu-py2').
# For backward compatibility map deprecated image tag '1.0' to a '1.4' framework version
# otherwise extract framework version from the tag itself.
init_params['framework_version'] = '1.4' if tag == '1.0' else framework_version_from_tag(tag)
init_params['framework_version'] = '1.4' if tag == '1.0' else framework_version_from_tag(
tag)

training_job_name = init_params['base_job_name']
if framework != cls.__framework_name__:
raise ValueError("Training job: {} didn't use image for requested framework".format(training_job_name))
raise ValueError("Training job: {} didn't use image for requested framework".format(
training_job_name))

return init_params

def create_model(self, model_server_workers=None, role=None, vpc_config_override=VPC_CONFIG_DEFAULT):
def create_model(self, model_server_workers=None, role=None,
vpc_config_override=VPC_CONFIG_DEFAULT, endpoint_type=None):
"""Create a SageMaker ``TensorFlowModel`` object that can be deployed to an ``Endpoint``.

Args:
@@ -305,18 +317,44 @@ def create_model(self, model_server_workers=None, role=None, vpc_config_override
Default: use subnets and security groups from this Estimator.
* 'Subnets' (list[str]): List of subnet ids.
* 'SecurityGroupIds' (list[str]): List of security group ids.
endpoint_type: Optional. Selects the software stack used by the inference server.
If not specified, the model will be configured to use the default
SageMaker model server. If 'tensorflow-serving', the model will be configured to
use the SageMaker Tensorflow Serving container.

Returns:
sagemaker.tensorflow.model.TensorFlowModel: A SageMaker ``TensorFlowModel`` object.
See :func:`~sagemaker.tensorflow.model.TensorFlowModel` for full details.
"""
env = {'SAGEMAKER_REQUIREMENTS': self.requirements_file}

role = role or self.role
return TensorFlowModel(self.model_data, role, self.entry_point, source_dir=self._model_source_dir(),
enable_cloudwatch_metrics=self.enable_cloudwatch_metrics, env=env, image=self.image_name,
name=self._current_job_name, container_log_level=self.container_log_level,
if endpoint_type == 'tensorflow-serving':
return self._create_tfs_model(role=role, vpc_config_override=vpc_config_override)

return self._create_default_model(model_server_workers=model_server_workers, role=role,
vpc_config_override=vpc_config_override)

def _create_tfs_model(self, role=None, vpc_config_override=VPC_CONFIG_DEFAULT):
return Model(model_data=self.model_data,
role=role,
image=self.image_name,
name=self._current_job_name,
container_log_level=self.container_log_level,
framework_version=self.framework_version,
sagemaker_session=self.sagemaker_session,
vpc_config=self.get_vpc_config(vpc_config_override))

def _create_default_model(self, model_server_workers, role, vpc_config_override):
return TensorFlowModel(self.model_data, role, self.entry_point,
source_dir=self._model_source_dir(),
enable_cloudwatch_metrics=self.enable_cloudwatch_metrics,
env={'SAGEMAKER_REQUIREMENTS': self.requirements_file},
image=self.image_name,
name=self._current_job_name,
container_log_level=self.container_log_level,
code_location=self.code_location, py_version=self.py_version,
framework_version=self.framework_version, model_server_workers=model_server_workers,
framework_version=self.framework_version,
model_server_workers=model_server_workers,
sagemaker_session=self.sagemaker_session,
vpc_config=self.get_vpc_config(vpc_config_override))

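To make the new `endpoint_type` switch concrete, a hedged end-to-end sketch follows. The role, entry point, S3 path, framework version, and instance types are placeholders; only `create_model(endpoint_type='tensorflow-serving')` and the resulting `deploy()` behavior come from this change.

```python
from sagemaker.tensorflow import TensorFlow

# Placeholder training setup; any working TensorFlow estimator configuration would do.
estimator = TensorFlow(entry_point='train.py',
                       role='SageMakerRole',
                       framework_version='1.11',
                       train_instance_count=1,
                       train_instance_type='ml.c4.xlarge')
estimator.fit('s3://my-bucket/training-data')

# endpoint_type='tensorflow-serving' routes through _create_tfs_model and returns a
# sagemaker.tensorflow.serving.Model; omitting it keeps the default TensorFlowModel.
model = estimator.create_model(endpoint_type='tensorflow-serving')
predictor = model.deploy(initial_instance_count=1, instance_type='ml.c5.xlarge')
```

Because `serving.Model` sets `predictor_cls=Predictor`, the `deploy()` call above returns the new TFS-aware `Predictor` defined in `serving.py` below.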
149 changes: 149 additions & 0 deletions src/sagemaker/tensorflow/serving.py
@@ -0,0 +1,149 @@
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
from __future__ import absolute_import

import logging

import sagemaker
from sagemaker.content_types import CONTENT_TYPE_JSON
from sagemaker.fw_utils import create_image_uri
from sagemaker.predictor import json_serializer, json_deserializer
from sagemaker.tensorflow.defaults import TF_VERSION


class Predictor(sagemaker.RealTimePredictor):
"""A ``RealTimePredictor`` implementation for inference against TensorFlow Serving endpoints.
"""

def __init__(self, endpoint_name, sagemaker_session=None,
serializer=json_serializer,
deserializer=json_deserializer,
model_name=None,
model_version=None):
"""Initialize a ``TFSPredictor``. See ``sagemaker.RealTimePredictor`` for
more info about parameters.

Args:
endpoint_name (str): The name of the endpoint to perform inference on.
sagemaker_session (sagemaker.session.Session): Session object which manages interactions
with Amazon SageMaker APIs and any other AWS services needed. If not specified,
the estimator creates one using the default AWS configuration chain.
serializer (callable): Optional. Default serializes input data to json. Handles dicts,
lists, and numpy arrays.
deserializer (callable): Optional. Default parses the response using ``json.load(...)``.
model_name (str): Optional. The name of the SavedModel model that should handle the
request. If not specified, the endpoint's default model will handle the request.
model_version (str): Optional. The version of the SavedModel model that should handle
the request. If not specified, the latest version of the model will be used.
"""
super(Predictor, self).__init__(endpoint_name, sagemaker_session, serializer,
deserializer)

attributes = []
if model_name:
attributes.append('tfs-model-name={}'.format(model_name))
if model_version:
attributes.append('tfs-model-version={}'.format(model_version))
self._model_attributes = ','.join(attributes) if attributes else None

def classify(self, data):
return self._classify_or_regress(data, 'classify')

def regress(self, data):
return self._classify_or_regress(data, 'regress')

def _classify_or_regress(self, data, method):
if method not in ['classify', 'regress']:
raise ValueError('invalid TensorFlow Serving method: {}'.format(method))

if self.content_type != CONTENT_TYPE_JSON:
raise ValueError('The {} api requires json requests.'.format(method))

args = {
'CustomAttributes': 'tfs-method={}'.format(method)
}

return self.predict(data, args)

def predict(self, data, initial_args=None):
args = dict(initial_args) if initial_args else {}
if self._model_attributes:
if 'CustomAttributes' in args:
args['CustomAttributes'] += ',' + self._model_attributes
else:
args['CustomAttributes'] = self._model_attributes

return super(Predictor, self).predict(data, args)


class Model(sagemaker.Model):
FRAMEWORK_NAME = 'tensorflow-serving'
LOG_LEVEL_PARAM_NAME = 'SAGEMAKER_TFS_NGINX_LOGLEVEL'
LOG_LEVEL_MAP = {
logging.DEBUG: 'debug',
logging.INFO: 'info',
logging.WARNING: 'warn',
logging.ERROR: 'error',
logging.CRITICAL: 'crit',
}

def __init__(self, model_data, role, image=None, framework_version=TF_VERSION,
container_log_level=None, predictor_cls=Predictor, **kwargs):
"""Initialize a Model.

Args:
model_data (str): The S3 location of a SageMaker model data ``.tar.gz`` file.
role (str): An AWS IAM role (either name or full ARN). The Amazon SageMaker APIs that
create Amazon SageMaker endpoints use this role to access model artifacts.
image (str): A Docker image URI (default: None). If not specified, a default image for
TensorFlow Serving will be used.
framework_version (str): Optional. TensorFlow Serving version you want to use.
container_log_level (int): Log level to use within the container (default: logging.ERROR).
Valid values are defined in the Python logging module.
predictor_cls (callable[str, sagemaker.session.Session]): A function to call to create a
predictor with an endpoint name and SageMaker ``Session``. If specified, ``deploy()``
returns the result of invoking this function on the created endpoint name.
**kwargs: Keyword arguments passed to the ``Model`` initializer.
"""
super(Model, self).__init__(model_data=model_data, role=role, image=image,
predictor_cls=predictor_cls, **kwargs)
self._framework_version = framework_version
self._container_log_level = container_log_level

def prepare_container_def(self, instance_type):
image = self._get_image_uri(instance_type)
env = self._get_container_env()
return sagemaker.container_def(image, self.model_data, env)

def _get_container_env(self):
if not self._container_log_level:
return self.env

if self._container_log_level not in Model.LOG_LEVEL_MAP:
logging.warning('ignoring invalid container log level: %s', self._container_log_level)
return self.env

env = dict(self.env)
env['SAGEMAKER_TFS_NGINX_LOGLEVEL'] = Model.LOG_LEVEL_MAP[self._container_log_level]
return env

def _get_image_uri(self, instance_type):
if self.image:
return self.image

# reuse standard image uri function, then strip unwanted python component
region_name = self.sagemaker_session.boto_region_name
image = create_image_uri(region_name, Model.FRAMEWORK_NAME, instance_type,
self._framework_version, 'py3')
image = image.replace('-py3', '')
return image
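Finally, a short, hedged sketch of using the new serving `Predictor` directly against an already-deployed TFS endpoint; the endpoint name, model name, and payloads are made up.

```python
from sagemaker.tensorflow.serving import Predictor

# Hypothetical endpoint backed by the tensorflow-serving container.
predictor = Predictor('my-tfs-endpoint', model_name='my_model', model_version='2')

# predict() forwards tfs-model-name / tfs-model-version as custom attributes,
# so a non-default SavedModel (and version) can be targeted per request.
print(predictor.predict({'instances': [[1.0, 2.0]]}))

# classify() and regress() additionally set tfs-method; both require JSON requests.
print(predictor.classify({'examples': [{'feature': [1.0, 2.0]}]}))
```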
Binary file added tests/data/tensorflow-serving-test-model.tar.gz