add tfs container support #460

Merged: 8 commits, Nov 7, 2018

Changes from 5 commits

5 changes: 5 additions & 0 deletions CHANGELOG.rst
@@ -2,6 +2,11 @@
CHANGELOG
=========

1.14.0-dev
==========

* feature: add support for sagemaker-tfs container

1.13.0
======

37 changes: 24 additions & 13 deletions src/sagemaker/predictor.py
@@ -57,35 +57,28 @@ def __init__(self, endpoint, sagemaker_session=None, serializer=None, deserializ
self.content_type = content_type or getattr(serializer, 'content_type', None)
self.accept = accept or getattr(deserializer, 'accept', None)

def predict(self, data):
def predict(self, data, initial_args=None):
"""Return the inference from the specified endpoint.

Args:
data (object): Input data for which you want the model to provide inference.
If a serializer was specified when creating the RealTimePredictor, the result of the
serializer is sent as input data. Otherwise the data must be sequence of bytes, and
the predict method then sends the bytes in the request body as is.
initial_args (dict[str,str]): Optional. Default arguments for boto3
``invoke_endpoint`` call. Default is None (no default arguments).

Returns:
object: Inference for the given input. If a deserializer was specified when creating
the RealTimePredictor, the result of the deserializer is returned. Otherwise the response
returns the sequence of bytes as is.
"""
if self.serializer is not None:
data = self.serializer(data)

request_args = {
'EndpointName': self.endpoint,
'Body': data
}

if self.content_type:
request_args['ContentType'] = self.content_type
if self.accept:
request_args['Accept'] = self.accept

request_args = self._create_request_args(data, initial_args)
response = self.sagemaker_session.sagemaker_runtime_client.invoke_endpoint(**request_args)
return self._handle_response(response)

def _handle_response(self, response):
response_body = response['Body']
if self.deserializer is not None:
# It's the deserializer's responsibility to close the stream
@@ -94,6 +87,24 @@ def predict(self, data):
response_body.close()
return data

def _create_request_args(self, data, initial_args=None):
args = dict(initial_args) if initial_args else {}

if 'EndpointName' not in args:
args['EndpointName'] = self.endpoint

if self.content_type and 'ContentType' not in args:
args['ContentType'] = self.content_type

if self.accept and 'Accept' not in args:
args['Accept'] = self.accept

if self.serializer is not None:
data = self.serializer(data)

args['Body'] = data
return args

def delete_endpoint(self):
"""Delete the Amazon SageMaker endpoint backing this predictor.
"""
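
The merge semantics introduced by _create_request_args are worth making concrete: anything the caller passes through initial_args is forwarded to boto3's invoke_endpoint as-is, and the predictor's stored defaults (EndpointName, ContentType, Accept) only fill in keys the caller left out. A minimal usage sketch, assuming a hypothetical endpoint name and JSON payload:

```python
from sagemaker.predictor import RealTimePredictor

# Hypothetical endpoint; content_type/accept become per-predictor defaults.
predictor = RealTimePredictor('my-endpoint',
                              content_type='application/json',
                              accept='application/json')

# Default call: EndpointName, ContentType, and Accept are filled in
# from the predictor, and the payload is sent as the request Body.
result = predictor.predict('{"instances": [1.0, 2.0]}')

# Per-call override: the caller's Accept wins over the stored default,
# because _create_request_args only sets keys that are absent.
csv_result = predictor.predict('{"instances": [1.0, 2.0]}',
                               initial_args={'Accept': 'text/csv'})
```
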
1 change: 1 addition & 0 deletions src/sagemaker/tensorflow/__init__.py
@@ -29,3 +29,4 @@

from sagemaker.tensorflow.estimator import TensorFlow # noqa: E402, F401
from sagemaker.tensorflow.model import TensorFlowModel, TensorFlowPredictor # noqa: E402, F401
from sagemaker.tensorflow.tfs import TFSModel, TFSPredictor # noqa: E402, F401
Contributor:

Let's avoid acronyms, to keep the same naming pattern as the rest of the SDK.

Contributor (author):

'TensorFlowServingModel' is unwieldy, and ambiguous, since TensorFlowModel also refers to a kind of TensorFlow Serving model. 'TFS' is well aligned with the module name and the container-side package name.

Another option is just calling them Model and Predictor (sagemaker.tensorflow.tfs.Model, etc.), but this invites collisions with sagemaker.Model, etc.

Contributor:

What about sagemaker.tensorflow.serving.Model or sagemaker.tensorflow_serving.Model?

Collisions should not happen if we import modules rather than classes, as specified in the style guide.
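
To make the collision point concrete, here is a minimal sketch of the module-import pattern, using the reviewer's proposed sagemaker.tensorflow.serving path (a hypothetical name at this stage of the review; this revision of the PR still ships sagemaker.tensorflow.tfs):

```python
# Import the module, not the class, per the style guide.
from sagemaker.tensorflow import serving

# Every use is qualified by the module name, so serving.Model cannot
# shadow or collide with sagemaker.model.Model at the call site.
model = serving.Model(model_data='s3://my-bucket/model.tar.gz',  # hypothetical artifact
                      role='MySageMakerRole')                    # hypothetical IAM role
```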

79 changes: 58 additions & 21 deletions src/sagemaker/tensorflow/estimator.py
@@ -22,12 +22,13 @@
import time

from sagemaker.estimator import Framework
from sagemaker.fw_utils import framework_name_from_image, framework_version_from_tag, empty_framework_version_warning
from sagemaker.utils import get_config_value
from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT

from sagemaker.fw_utils import framework_name_from_image, framework_version_from_tag, \
empty_framework_version_warning
from sagemaker.tensorflow.defaults import TF_VERSION
from sagemaker.tensorflow.model import TensorFlowModel
from sagemaker.tensorflow.tfs import TFSModel
from sagemaker.utils import get_config_value
from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT

logging.basicConfig()
LOGGER = logging.getLogger('sagemaker')
@@ -103,12 +104,14 @@ def validate_requirements(self):
EnvironmentError: If at least one requirement is not installed.
"""
if not self._cmd_exists('tensorboard'):
raise EnvironmentError('TensorBoard is not installed in the system. Please install TensorBoard using the'
' following command: \n pip install tensorboard')
raise EnvironmentError(
'TensorBoard is not installed in the system. Please install TensorBoard using the'
' following command: \n pip install tensorboard')

if not self._cmd_exists('aws'):
raise EnvironmentError('The AWS CLI is not installed in the system. Please install the AWS CLI using the'
' following command: \n pip install awscli')
raise EnvironmentError(
'The AWS CLI is not installed in the system. Please install the AWS CLI using the'
' following command: \n pip install awscli')

def create_tensorboard_process(self):
"""Create a TensorBoard process.
@@ -125,7 +128,8 @@ def create_tensorboard_process(self):

for i in range(100):
p = subprocess.Popen(
["tensorboard", "--logdir", self.logdir, "--host", "localhost", "--port", str(port)],
["tensorboard", "--logdir", self.logdir, "--host", "localhost", "--port",
str(port)],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE
)
@@ -135,7 +139,8 @@
else:
return port, p

raise OSError('No available ports to start TensorBoard. Attempted all ports between 6006 and 6105')
raise OSError(
'No available ports to start TensorBoard. Attempted all ports between 6006 and 6105')

def run(self):
"""Run TensorBoard process."""
@@ -158,7 +163,8 @@ class TensorFlow(Framework):

__framework_name__ = 'tensorflow'

def __init__(self, training_steps=None, evaluation_steps=None, checkpoint_path=None, py_version='py2',
def __init__(self, training_steps=None, evaluation_steps=None, checkpoint_path=None,
py_version='py2',
framework_version=None, requirements_file='', image_name=None, **kwargs):
"""Initialize an ``TensorFlow`` estimator.
Args:
@@ -202,7 +208,8 @@ def _validate_requirements_file(self, requirements_file):
raise ValueError('Must specify source_dir along with a requirements file.')

if os.path.isabs(requirements_file):
raise ValueError('Requirements file {} is not a path relative to source_dir.'.format(requirements_file))
raise ValueError('Requirements file {} is not a path relative to source_dir.'.format(
requirements_file))

if not os.path.exists(os.path.join(self.source_dir, requirements_file)):
raise ValueError('Requirements file {} does not exist.'.format(requirements_file))
@@ -231,6 +238,7 @@ def fit(self, inputs, wait=True, logs=True, job_name=None, run_tensorboard_local
downloaded checkpoint information (default: False). This is an experimental feature, and requires
TensorBoard and AWS CLI to be installed. It terminates TensorBoard when execution ends.
"""

def fit_super():
super(TensorFlow, self).fit(inputs, wait, logs, job_name)

@@ -263,7 +271,8 @@ def _prepare_init_params_from_job_description(cls, job_details, model_channel_na
dictionary: The transformed init_params

"""
init_params = super(TensorFlow, cls)._prepare_init_params_from_job_description(job_details, model_channel_name)
init_params = super(TensorFlow, cls)._prepare_init_params_from_job_description(job_details,
model_channel_name)

# Move some of the tensorflow specific init params from hyperparameters into the main init params.
for argument in ['checkpoint_path', 'training_steps', 'evaluation_steps']:
@@ -285,15 +294,18 @@ def _prepare_init_params_from_job_description(cls, job_details, model_channel_na
# containing framework version, device type and python version (e.g. '1.5-gpu-py2').
# For backward compatibility map deprecated image tag '1.0' to a '1.4' framework version
# otherwise extract framework version from the tag itself.
init_params['framework_version'] = '1.4' if tag == '1.0' else framework_version_from_tag(tag)
init_params['framework_version'] = '1.4' if tag == '1.0' else framework_version_from_tag(
tag)

training_job_name = init_params['base_job_name']
if framework != cls.__framework_name__:
raise ValueError("Training job: {} didn't use image for requested framework".format(training_job_name))
raise ValueError("Training job: {} didn't use image for requested framework".format(
training_job_name))

return init_params

def create_model(self, model_server_workers=None, role=None, vpc_config_override=VPC_CONFIG_DEFAULT):
def create_model(self, model_server_workers=None, role=None,
vpc_config_override=VPC_CONFIG_DEFAULT, endpoint_type=None):
"""Create a SageMaker ``TensorFlowModel`` object that can be deployed to an ``Endpoint``.

Args:
@@ -305,18 +317,43 @@ def create_model(self, model_server_workers=None, role=None, vpc_config_override
Default: use subnets and security groups from this Estimator.
* 'Subnets' (list[str]): List of subnet ids.
* 'SecurityGroupIds' (list[str]): List of security group ids.
endpoint_type: Optional. Selects the software stack used by the inference server.
If not specified, the model will be configured to use the default SageMaker model server.
If 'tfs', the model will be configured to use the SageMaker TensorFlow Serving container.

Returns:
sagemaker.tensorflow.model.TensorFlowModel: A SageMaker ``TensorFlowModel`` object.
See :func:`~sagemaker.tensorflow.model.TensorFlowModel` for full details.
"""
env = {'SAGEMAKER_REQUIREMENTS': self.requirements_file}

role = role or self.role
return TensorFlowModel(self.model_data, role, self.entry_point, source_dir=self._model_source_dir(),
enable_cloudwatch_metrics=self.enable_cloudwatch_metrics, env=env, image=self.image_name,
name=self._current_job_name, container_log_level=self.container_log_level,
if endpoint_type == 'tfs':
return self._create_tfs_model(role=role, vpc_config_override=vpc_config_override)

return self._create_default_model(model_server_workers=model_server_workers, role=role,
vpc_config_override=vpc_config_override)

def _create_tfs_model(self, role=None, vpc_config_override=VPC_CONFIG_DEFAULT):
return TFSModel(model_data=self.model_data,
role=role,
image=self.image_name,
name=self._current_job_name,
container_log_level=self.container_log_level,
framework_version=self.framework_version,
sagemaker_session=self.sagemaker_session,
vpc_config=self.get_vpc_config(vpc_config_override))
Contributor:

Are you missing container_log_level?

Contributor (author):

Yes.

def _create_default_model(self, model_server_workers, role, vpc_config_override):
return TensorFlowModel(self.model_data, role, self.entry_point,
source_dir=self._model_source_dir(),
enable_cloudwatch_metrics=self.enable_cloudwatch_metrics,
env={'SAGEMAKER_REQUIREMENTS': self.requirements_file},
image=self.image_name,
name=self._current_job_name,
container_log_level=self.container_log_level,
code_location=self.code_location, py_version=self.py_version,
framework_version=self.framework_version, model_server_workers=model_server_workers,
framework_version=self.framework_version,
model_server_workers=model_server_workers,
sagemaker_session=self.sagemaker_session,
vpc_config=self.get_vpc_config(vpc_config_override))

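
To round out the picture, a short end-to-end sketch of the new endpoint_type switch; the entry point, role, bucket, and instance types below are hypothetical placeholders, and the rest of create_model's behavior is unchanged:

```python
from sagemaker.tensorflow import TensorFlow

estimator = TensorFlow(entry_point='train.py',           # hypothetical training script
                       role='MySageMakerRole',           # hypothetical IAM role
                       train_instance_count=1,
                       train_instance_type='ml.p3.2xlarge',
                       framework_version='1.11')
estimator.fit('s3://my-bucket/training-data')             # hypothetical S3 prefix

# endpoint_type='tfs' returns a TFSModel backed by the SageMaker
# TensorFlow Serving container; omitting it returns the default
# TensorFlowModel, exactly as before this change.
model = estimator.create_model(endpoint_type='tfs')
predictor = model.deploy(initial_instance_count=1, instance_type='ml.c5.xlarge')
```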