
add tfs container support #460


Merged: 8 commits, Nov 7, 2018
36 changes: 23 additions & 13 deletions src/sagemaker/predictor.py
@@ -57,35 +57,27 @@ def __init__(self, endpoint, sagemaker_session=None, serializer=None, deserializ
self.content_type = content_type or getattr(serializer, 'content_type', None)
self.accept = accept or getattr(deserializer, 'accept', None)

def predict(self, data):
def predict(self, data, initial_args=None):
"""Return the inference from the specified endpoint.

Args:
data (object): Input data for which you want the model to provide inference.
If a serializer was specified when creating the RealTimePredictor, the result of the
serializer is sent as input data. Otherwise the data must be a sequence of bytes, and
the predict method then sends the bytes in the request body as is.
initial_args (dict[str,str]): Optional. Initial request arguments. Default is None.
Contributor: Please clarify this docstring argument further.


Returns:
object: Inference for the given input. If a deserializer was specified when creating
the RealTimePredictor, the result of the deserializer is returned. Otherwise the response
returns the sequence of bytes as is.
"""
if self.serializer is not None:
data = self.serializer(data)

request_args = {
'EndpointName': self.endpoint,
'Body': data
}

if self.content_type:
request_args['ContentType'] = self.content_type
if self.accept:
request_args['Accept'] = self.accept

request_args = self._create_request_args(data, initial_args)
response = self.sagemaker_session.sagemaker_runtime_client.invoke_endpoint(**request_args)
return self._handle_response(response)

def _handle_response(self, response):
response_body = response['Body']
if self.deserializer is not None:
# It's the deserializer's responsibility to close the stream
@@ -94,6 +86,24 @@ def predict(self, data):
response_body.close()
return data

def _create_request_args(self, data, initial_args=None):
args = dict(initial_args) if initial_args else {}

if 'EndpointName' not in args:
args['EndpointName'] = self.endpoint

if self.content_type and 'ContentType' not in args:
args['ContentType'] = self.content_type

if self.accept and 'Accept' not in args:
args['Accept'] = self.accept

if self.serializer is not None:
data = self.serializer(data)

args['Body'] = data
return args

def delete_endpoint(self):
"""Delete the Amazon SageMaker endpoint backing this predictor.
"""
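To illustrate the new parameter, a minimal usage sketch (the endpoint name and payload below are placeholders, not from this PR). Because _create_request_args only fills in keys that are absent, anything passed in initial_args wins over the predictor's own defaults:

from sagemaker.predictor import RealTimePredictor, json_serializer, json_deserializer

# 'my-endpoint' is an illustrative endpoint name.
predictor = RealTimePredictor('my-endpoint',
                              serializer=json_serializer,
                              deserializer=json_deserializer)

# Plain call: EndpointName, ContentType, Accept and Body are all derived
# from the predictor's own attributes by _create_request_args.
result = predictor.predict({'instances': [1.0, 2.0]})

# Per-request override: keys already present in initial_args are kept, so
# this call sends a different Accept header without mutating the predictor.
result = predictor.predict({'instances': [1.0, 2.0]},
                           initial_args={'Accept': 'application/json'})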
1 change: 1 addition & 0 deletions src/sagemaker/tensorflow/__init__.py
@@ -29,3 +29,4 @@

from sagemaker.tensorflow.estimator import TensorFlow # noqa: E402, F401
from sagemaker.tensorflow.model import TensorFlowModel, TensorFlowPredictor # noqa: E402, F401
from sagemaker.tensorflow.tfs import TFSModel, TFSPredictor # noqa: E402, F401
Contributor: Let's avoid acronyms, to keep the same pattern as the rest of the SDK.

Contributor (author): 'TensorFlowServingModel' is unwieldy, and ambiguous, since TensorFlowModel also refers to a kind of TensorFlow Serving model. 'TFS' aligns well with the module name and the container-side package name.

Another option... just calling them Model and Predictor (sagemaker.tensorflow.tfs.Model etc), but this invites collision with sagemaker.Model etc.

Contributor: What about sagemaker.tensorflow.serving.Model or sagemaker.tensorflow_serving.Model?

Collisions shouldn't happen if we import modules rather than classes, as the style guide specifies.

78 changes: 57 additions & 21 deletions src/sagemaker/tensorflow/estimator.py
@@ -22,12 +22,13 @@
import time

from sagemaker.estimator import Framework
from sagemaker.fw_utils import framework_name_from_image, framework_version_from_tag, empty_framework_version_warning
from sagemaker.utils import get_config_value
from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT

from sagemaker.fw_utils import framework_name_from_image, framework_version_from_tag, \
empty_framework_version_warning
from sagemaker.tensorflow.defaults import TF_VERSION
from sagemaker.tensorflow.model import TensorFlowModel
from sagemaker.tensorflow.tfs import TFSModel
from sagemaker.utils import get_config_value
from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT

logging.basicConfig()
LOGGER = logging.getLogger('sagemaker')
@@ -103,12 +104,14 @@ def validate_requirements(self):
EnvironmentError: If at least one requirement is not installed.
"""
if not self._cmd_exists('tensorboard'):
raise EnvironmentError('TensorBoard is not installed in the system. Please install TensorBoard using the'
' following command: \n pip install tensorboard')
raise EnvironmentError(
'TensorBoard is not installed in the system. Please install TensorBoard using the'
' following command: \n pip install tensorboard')

if not self._cmd_exists('aws'):
raise EnvironmentError('The AWS CLI is not installed in the system. Please install the AWS CLI using the'
' following command: \n pip install awscli')
raise EnvironmentError(
'The AWS CLI is not installed in the system. Please install the AWS CLI using the'
' following command: \n pip install awscli')

def create_tensorboard_process(self):
"""Create a TensorBoard process.
@@ -125,7 +128,8 @@ def create_tensorboard_process(self):

for i in range(100):
p = subprocess.Popen(
["tensorboard", "--logdir", self.logdir, "--host", "localhost", "--port", str(port)],
["tensorboard", "--logdir", self.logdir, "--host", "localhost", "--port",
str(port)],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE
)
@@ -135,7 +139,8 @@ def create_tensorboard_process(self):
else:
return port, p

raise OSError('No available ports to start TensorBoard. Attempted all ports between 6006 and 6105')
raise OSError(
'No available ports to start TensorBoard. Attempted all ports between 6006 and 6105')

def run(self):
"""Run TensorBoard process."""
@@ -158,7 +163,8 @@ class TensorFlow(Framework):

__framework_name__ = 'tensorflow'

def __init__(self, training_steps=None, evaluation_steps=None, checkpoint_path=None, py_version='py2',
def __init__(self, training_steps=None, evaluation_steps=None, checkpoint_path=None,
py_version='py2',
framework_version=None, requirements_file='', image_name=None, **kwargs):
"""Initialize an ``TensorFlow`` estimator.
Args:
@@ -202,7 +208,8 @@ def _validate_requirements_file(self, requirements_file):
raise ValueError('Must specify source_dir along with a requirements file.')

if os.path.isabs(requirements_file):
raise ValueError('Requirements file {} is not a path relative to source_dir.'.format(requirements_file))
raise ValueError('Requirements file {} is not a path relative to source_dir.'.format(
requirements_file))

if not os.path.exists(os.path.join(self.source_dir, requirements_file)):
raise ValueError('Requirements file {} does not exist.'.format(requirements_file))
@@ -231,6 +238,7 @@ def fit(self, inputs, wait=True, logs=True, job_name=None, run_tensorboard_local
downloaded checkpoint information (default: False). This is an experimental feature, and requires
TensorBoard and AWS CLI to be installed. It terminates TensorBoard when execution ends.
"""

def fit_super():
super(TensorFlow, self).fit(inputs, wait, logs, job_name)

@@ -263,7 +271,8 @@ def _prepare_init_params_from_job_description(cls, job_details, model_channel_na
dictionary: The transformed init_params

"""
init_params = super(TensorFlow, cls)._prepare_init_params_from_job_description(job_details, model_channel_name)
init_params = super(TensorFlow, cls)._prepare_init_params_from_job_description(job_details,
model_channel_name)

# Move some of the tensorflow specific init params from hyperparameters into the main init params.
for argument in ['checkpoint_path', 'training_steps', 'evaluation_steps']:
@@ -285,15 +294,18 @@ def _prepare_init_params_from_job_description(cls, job_details, model_channel_na
# containing framework version, device type and python version (e.g. '1.5-gpu-py2').
# For backward compatibility map deprecated image tag '1.0' to a '1.4' framework version
# otherwise extract framework version from the tag itself.
init_params['framework_version'] = '1.4' if tag == '1.0' else framework_version_from_tag(tag)
init_params['framework_version'] = '1.4' if tag == '1.0' else framework_version_from_tag(
tag)

training_job_name = init_params['base_job_name']
if framework != cls.__framework_name__:
raise ValueError("Training job: {} didn't use image for requested framework".format(training_job_name))
raise ValueError("Training job: {} didn't use image for requested framework".format(
training_job_name))

return init_params

def create_model(self, model_server_workers=None, role=None, vpc_config_override=VPC_CONFIG_DEFAULT):
def create_model(self, model_server_workers=None, role=None,
vpc_config_override=VPC_CONFIG_DEFAULT, endpoint_type=None):
"""Create a SageMaker ``TensorFlowModel`` object that can be deployed to an ``Endpoint``.

Args:
@@ -305,18 +317,42 @@ def create_model(self, model_server_workers=None, role=None, vpc_config_override
Default: use subnets and security groups from this Estimator.
* 'Subnets' (list[str]): List of subnet ids.
* 'SecurityGroupIds' (list[str]): List of security group ids.
endpoint_type: Optional. Selects the software stack used by the inference server.
If not specified, the model will be configured to use the default SageMaker model server.
If 'tfs', the model will be configured to use the SageMaker TensorFlow Serving container.

Returns:
sagemaker.tensorflow.model.TensorFlowModel: A SageMaker ``TensorFlowModel`` object.
See :func:`~sagemaker.tensorflow.model.TensorFlowModel` for full details.
"""
env = {'SAGEMAKER_REQUIREMENTS': self.requirements_file}

role = role or self.role
return TensorFlowModel(self.model_data, role, self.entry_point, source_dir=self._model_source_dir(),
enable_cloudwatch_metrics=self.enable_cloudwatch_metrics, env=env, image=self.image_name,
name=self._current_job_name, container_log_level=self.container_log_level,
if endpoint_type == 'tfs':
return self._create_tfs_model(role=role, vpc_config_override=vpc_config_override)

return self._create_default_model(model_server_workers=model_server_workers, role=role,
vpc_config_override=vpc_config_override)

def _create_tfs_model(self, role=None, vpc_config_override=VPC_CONFIG_DEFAULT):
return TFSModel(model_data=self.model_data,
role=role,
image=self.image_name,
name=self._current_job_name,
framework_version=self.framework_version,
sagemaker_session=self.sagemaker_session,
vpc_config=self.get_vpc_config(vpc_config_override))
Contributor: Are you missing container_log_level?

Contributor (author): Yes.
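If the fix is just threading the estimator's existing attribute through, the amended helper would presumably read as follows (a sketch, not the final commit):

def _create_tfs_model(self, role=None, vpc_config_override=VPC_CONFIG_DEFAULT):
    # Sketch of the fix discussed above: forward container_log_level, which
    # TFSModel.__init__ already accepts, alongside the other arguments.
    return TFSModel(model_data=self.model_data,
                    role=role,
                    image=self.image_name,
                    name=self._current_job_name,
                    container_log_level=self.container_log_level,
                    framework_version=self.framework_version,
                    sagemaker_session=self.sagemaker_session,
                    vpc_config=self.get_vpc_config(vpc_config_override))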


def _create_default_model(self, model_server_workers, role, vpc_config_override):
return TensorFlowModel(self.model_data, role, self.entry_point,
source_dir=self._model_source_dir(),
enable_cloudwatch_metrics=self.enable_cloudwatch_metrics,
env={'SAGEMAKER_REQUIREMENTS': self.requirements_file},
image=self.image_name,
name=self._current_job_name,
container_log_level=self.container_log_level,
code_location=self.code_location, py_version=self.py_version,
framework_version=self.framework_version, model_server_workers=model_server_workers,
framework_version=self.framework_version,
model_server_workers=model_server_workers,
sagemaker_session=self.sagemaker_session,
vpc_config=self.get_vpc_config(vpc_config_override))

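For context, an end-to-end sketch of the new endpoint_type switch (the entry point, role, S3 paths, and instance types below are placeholders):

from sagemaker.tensorflow import TensorFlow

# All names here are illustrative, not from this PR.
estimator = TensorFlow(entry_point='train.py',
                       role='SageMakerRole',
                       training_steps=1000,
                       evaluation_steps=100,
                       train_instance_count=1,
                       train_instance_type='ml.c4.xlarge')
estimator.fit('s3://my-bucket/training-data')

# endpoint_type='tfs' routes create_model() to _create_tfs_model(), so the
# endpoint runs the TensorFlow Serving container rather than the default
# SageMaker model server.
model = estimator.create_model(endpoint_type='tfs')
predictor = model.deploy(initial_instance_count=1, instance_type='ml.c4.xlarge')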
150 changes: 150 additions & 0 deletions src/sagemaker/tensorflow/tfs.py
@@ -0,0 +1,150 @@
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
from __future__ import absolute_import

import logging

import sagemaker
from sagemaker import Model, RealTimePredictor
from sagemaker.content_types import CONTENT_TYPE_JSON
from sagemaker.fw_utils import create_image_uri
from sagemaker.predictor import json_serializer, json_deserializer
Contributor: nit: per the Google Python style guide, import modules, not functions and classes.

from sagemaker.tensorflow.defaults import TF_VERSION


class TFSPredictor(RealTimePredictor):
"""A ``RealTimePredictor`` implementation for inference against TFS endpoints.
"""

def __init__(self, endpoint_name, sagemaker_session=None,
serializer=json_serializer,
deserializer=json_deserializer,
model_name=None,
model_version=None):
"""Initialize a ``TFSPredictor``. See ``sagemaker.RealTimePredictor`` for
more info about parameters.

Args:
endpoint_name (str): The name of the endpoint to perform inference on.
sagemaker_session (sagemaker.session.Session): Session object which manages interactions
with Amazon SageMaker APIs and any other AWS services needed. If not specified,
the estimator creates one using the default AWS configuration chain.
serializer (callable): Optional. Default serializes input data to json. Handles dicts,
lists, and numpy arrays.
deserializer (callable): Optional. Default parses the response using ``json.load(...)``.
model_name (str): Optional. The name of the TFS model that should handle the request.
If not specified, the endpoint's default model will handle the request.
model_version (str): Optional. The version of the TFS model that should handle the
request. If not specified, the latest version of the model will be used.
"""
super(TFSPredictor, self).__init__(endpoint_name, sagemaker_session, serializer,
deserializer)

attributes = []
if model_name:
attributes.append('tfs-model-name={}'.format(model_name))
if model_version:
attributes.append('tfs-model-version={}'.format(model_version))
self._model_attributes = ','.join(attributes) if attributes else None
Contributor: Is it possible to use an empty array to represent an empty list of model attributes instead of None?
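For what it's worth, that suggestion might look like this (a sketch; the join would then happen per request rather than in __init__):

# Hypothetical alternative: always keep a (possibly empty) list...
self._model_attributes = attributes

# ...and join it lazily in predict(), only when it is non-empty:
if self._model_attributes:
    args['CustomAttributes'] = ','.join(self._model_attributes)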


def classify(self, data):
return self._classify_or_regress(data, 'classify')

def regress(self, data):
return self._classify_or_regress(data, 'regress')

def _classify_or_regress(self, data, method):
if method not in ['classify', 'regress']:
raise ValueError('invalid TensorFlow Serving method: {}'.format(method))

if self.content_type != CONTENT_TYPE_JSON:
raise ValueError('The {} api requires json requests.'.format(method))

args = {
'CustomAttributes': 'tfs-method={}'.format(method)
}

return self.predict(data, args)

def predict(self, data, initial_args=None):
args = dict(initial_args) if initial_args else {}
if self._model_attributes:
if 'CustomAttributes' in args:
args['CustomAttributes'] += ',' + self._model_attributes
else:
args['CustomAttributes'] = self._model_attributes

return super(TFSPredictor, self).predict(data, args)


class TFSModel(Model):
FRAMEWORK_NAME = 'tfs'
Contributor: I do believe that other classes use dunders instead of a constant, although I like your style better and it matches the conventions: https://github.com/google/styleguide/blob/gh-pages/pyguide.md#3164-guidelines-derived-from-guidos-recommendations

nit: consider leaving it at module level

LOG_LEVEL_PARAM_NAME = 'SAGEMAKER_TFS_NGINX_LOGLEVEL'
LOG_LEVEL_MAP = {
logging.DEBUG: 'debug',
logging.INFO: 'info',
logging.WARNING: 'warn',
logging.ERROR: 'error',
logging.CRITICAL: 'crit',
}
Contributor: Same here.

Contributor: Dictionaries are not constants in Python.


def __init__(self, model_data, role, image=None, framework_version=TF_VERSION,
container_log_level=None, predictor_cls=TFSPredictor, **kwargs):
"""Initialize a TFSModel.

Args:
model_data (str): The S3 location of a SageMaker model data ``.tar.gz`` file.
role (str): An AWS IAM role (either name or full ARN). The Amazon SageMaker APIs that
create Amazon SageMaker endpoints use this role to access model artifacts.
image (str): A Docker image URI (default: None). If not specified, a default image for
TensorFlow Serving will be used.
framework_version (str): Optional. TensorFlow Serving version you want to use.
container_log_level (int): Log level to use within the container (default: logging.ERROR).
Valid values are defined in the Python logging module.
predictor_cls (callable[str, sagemaker.session.Session]): A function to call to create a
predictor with an endpoint name and SageMaker ``Session``. If specified, ``deploy()``
returns the result of invoking this function on the created endpoint name.
**kwargs: Keyword arguments passed to the ``Model`` initializer.
"""
super(TFSModel, self).__init__(model_data=model_data, role=role, image=image,
predictor_cls=predictor_cls, **kwargs)
self._framework_version = framework_version
self._container_log_level = container_log_level

def prepare_container_def(self, instance_type):
image = self._get_image_uri(instance_type)
env = self._get_container_env()
return sagemaker.container_def(image, self.model_data, env)

def _get_container_env(self):
if not self._container_log_level:
return self.env

if self._container_log_level not in TFSModel.LOG_LEVEL_MAP:
logging.warning('ignoring invalid container log level: %s', self._container_log_level)
return self.env

env = dict(self.env)
env['SAGEMAKER_TFS_NGINX_LOGLEVEL'] = TFSModel.LOG_LEVEL_MAP[self._container_log_level]
Contributor: Let's use a constant for this key.

return env
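Note that the class already defines LOG_LEVEL_PARAM_NAME for exactly this string, so the change requested above is presumably a one-liner:

# Reuse the existing class constant instead of the string literal.
env[TFSModel.LOG_LEVEL_PARAM_NAME] = TFSModel.LOG_LEVEL_MAP[self._container_log_level]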

def _get_image_uri(self, instance_type):
if self.image:
return self.image

# reuse standard image uri function, then strip unwanted python component
region_name = self.sagemaker_session.boto_region_name
image = create_image_uri(region_name, TFSModel.FRAMEWORK_NAME, instance_type,
self._framework_version, 'py3')
Contributor: How hard would it be to change create_image_uri instead of reusing it and stripping the result?

Contributor (author): There's a lot of fussy behavior in create_image_uri; it would be easy to fall out of sync.

image = image.replace('-py3', '')
return image
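Putting the new classes together, a minimal deployment-and-inference sketch (the S3 path, role, framework version, and instance type are placeholders):

import logging

from sagemaker.tensorflow.tfs import TFSModel

# All names here are illustrative.
model = TFSModel(model_data='s3://my-bucket/model.tar.gz',
                 role='SageMakerRole',
                 framework_version='1.11',
                 container_log_level=logging.INFO)

# predictor_cls defaults to TFSPredictor, so deploy() returns one.
predictor = model.deploy(initial_instance_count=1, instance_type='ml.c5.xlarge')

# The default JSON serializer/deserializer handle dicts, lists and numpy arrays.
print(predictor.predict({'instances': [1.0, 2.0, 3.0]}))

# classify()/regress() attach tfs-method=... as a custom attribute; they
# require the JSON content type and raise ValueError otherwise.
print(predictor.classify({'examples': [{'x': 1.0}]}))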
Binary file added tests/data/tfs-test-model.tar.gz