-
Notifications
You must be signed in to change notification settings - Fork 1.2k
add tfs container support #460
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
d17252a
cb46dd2
75bf893
4c60c23
669d0db
755d082
61dd059
8d8d961
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -29,3 +29,4 @@ | |
|
||
from sagemaker.tensorflow.estimator import TensorFlow # noqa: E402, F401 | ||
from sagemaker.tensorflow.model import TensorFlowModel, TensorFlowPredictor # noqa: E402, F401 | ||
from sagemaker.tensorflow.tfs import TFSModel, TFSPredictor # noqa: E402, F401 | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Let's avoid acronyms to keep the same pattern of the SDK There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. 'TensorFlowServingModel' is unwieldy, and ambiguous since TensorFlowModel also refers to a kind of TensorFlow Serving model. TFS is well aligned with the module name, and the container side package name. Another option... just calling them Model and Predictor (sagemaker.tensorflow.tfs.Model etc), but this invites collision with sagemaker.Model etc. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. What about sagemaker.tensorflow.serving.Model or sagemaker.tensorflow_serving.Model? Collision should not happen if we are importing modules not classes like specified in the style guide. |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -22,12 +22,13 @@ | |
import time | ||
|
||
from sagemaker.estimator import Framework | ||
from sagemaker.fw_utils import framework_name_from_image, framework_version_from_tag, empty_framework_version_warning | ||
from sagemaker.utils import get_config_value | ||
from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT | ||
|
||
from sagemaker.fw_utils import framework_name_from_image, framework_version_from_tag, \ | ||
empty_framework_version_warning | ||
from sagemaker.tensorflow.defaults import TF_VERSION | ||
from sagemaker.tensorflow.model import TensorFlowModel | ||
from sagemaker.tensorflow.tfs import TFSModel | ||
from sagemaker.utils import get_config_value | ||
from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT | ||
|
||
logging.basicConfig() | ||
LOGGER = logging.getLogger('sagemaker') | ||
|
@@ -103,12 +104,14 @@ def validate_requirements(self): | |
EnvironmentError: If at least one requirement is not installed. | ||
""" | ||
if not self._cmd_exists('tensorboard'): | ||
raise EnvironmentError('TensorBoard is not installed in the system. Please install TensorBoard using the' | ||
' following command: \n pip install tensorboard') | ||
raise EnvironmentError( | ||
'TensorBoard is not installed in the system. Please install TensorBoard using the' | ||
' following command: \n pip install tensorboard') | ||
|
||
if not self._cmd_exists('aws'): | ||
raise EnvironmentError('The AWS CLI is not installed in the system. Please install the AWS CLI using the' | ||
' following command: \n pip install awscli') | ||
raise EnvironmentError( | ||
'The AWS CLI is not installed in the system. Please install the AWS CLI using the' | ||
' following command: \n pip install awscli') | ||
|
||
def create_tensorboard_process(self): | ||
"""Create a TensorBoard process. | ||
|
@@ -125,7 +128,8 @@ def create_tensorboard_process(self): | |
|
||
for i in range(100): | ||
p = subprocess.Popen( | ||
["tensorboard", "--logdir", self.logdir, "--host", "localhost", "--port", str(port)], | ||
["tensorboard", "--logdir", self.logdir, "--host", "localhost", "--port", | ||
str(port)], | ||
stdout=subprocess.PIPE, | ||
stderr=subprocess.PIPE | ||
) | ||
|
@@ -135,7 +139,8 @@ def create_tensorboard_process(self): | |
else: | ||
return port, p | ||
|
||
raise OSError('No available ports to start TensorBoard. Attempted all ports between 6006 and 6105') | ||
raise OSError( | ||
'No available ports to start TensorBoard. Attempted all ports between 6006 and 6105') | ||
|
||
def run(self): | ||
"""Run TensorBoard process.""" | ||
|
@@ -158,7 +163,8 @@ class TensorFlow(Framework): | |
|
||
__framework_name__ = 'tensorflow' | ||
|
||
def __init__(self, training_steps=None, evaluation_steps=None, checkpoint_path=None, py_version='py2', | ||
def __init__(self, training_steps=None, evaluation_steps=None, checkpoint_path=None, | ||
py_version='py2', | ||
framework_version=None, requirements_file='', image_name=None, **kwargs): | ||
"""Initialize an ``TensorFlow`` estimator. | ||
Args: | ||
|
@@ -202,7 +208,8 @@ def _validate_requirements_file(self, requirements_file): | |
raise ValueError('Must specify source_dir along with a requirements file.') | ||
|
||
if os.path.isabs(requirements_file): | ||
raise ValueError('Requirements file {} is not a path relative to source_dir.'.format(requirements_file)) | ||
raise ValueError('Requirements file {} is not a path relative to source_dir.'.format( | ||
requirements_file)) | ||
|
||
if not os.path.exists(os.path.join(self.source_dir, requirements_file)): | ||
raise ValueError('Requirements file {} does not exist.'.format(requirements_file)) | ||
|
@@ -231,6 +238,7 @@ def fit(self, inputs, wait=True, logs=True, job_name=None, run_tensorboard_local | |
downloaded checkpoint information (default: False). This is an experimental feature, and requires | ||
TensorBoard and AWS CLI to be installed. It terminates TensorBoard when execution ends. | ||
""" | ||
|
||
def fit_super(): | ||
super(TensorFlow, self).fit(inputs, wait, logs, job_name) | ||
|
||
|
@@ -263,7 +271,8 @@ def _prepare_init_params_from_job_description(cls, job_details, model_channel_na | |
dictionary: The transformed init_params | ||
|
||
""" | ||
init_params = super(TensorFlow, cls)._prepare_init_params_from_job_description(job_details, model_channel_name) | ||
init_params = super(TensorFlow, cls)._prepare_init_params_from_job_description(job_details, | ||
model_channel_name) | ||
|
||
# Move some of the tensorflow specific init params from hyperparameters into the main init params. | ||
for argument in ['checkpoint_path', 'training_steps', 'evaluation_steps']: | ||
|
@@ -285,15 +294,18 @@ def _prepare_init_params_from_job_description(cls, job_details, model_channel_na | |
# containing framework version, device type and python version (e.g. '1.5-gpu-py2'). | ||
# For backward compatibility map deprecated image tag '1.0' to a '1.4' framework version | ||
# otherwise extract framework version from the tag itself. | ||
init_params['framework_version'] = '1.4' if tag == '1.0' else framework_version_from_tag(tag) | ||
init_params['framework_version'] = '1.4' if tag == '1.0' else framework_version_from_tag( | ||
tag) | ||
|
||
training_job_name = init_params['base_job_name'] | ||
if framework != cls.__framework_name__: | ||
raise ValueError("Training job: {} didn't use image for requested framework".format(training_job_name)) | ||
raise ValueError("Training job: {} didn't use image for requested framework".format( | ||
training_job_name)) | ||
|
||
return init_params | ||
|
||
def create_model(self, model_server_workers=None, role=None, vpc_config_override=VPC_CONFIG_DEFAULT): | ||
def create_model(self, model_server_workers=None, role=None, | ||
vpc_config_override=VPC_CONFIG_DEFAULT, endpoint_type=None): | ||
"""Create a SageMaker ``TensorFlowModel`` object that can be deployed to an ``Endpoint``. | ||
|
||
Args: | ||
|
@@ -305,18 +317,42 @@ def create_model(self, model_server_workers=None, role=None, vpc_config_override | |
Default: use subnets and security groups from this Estimator. | ||
* 'Subnets' (list[str]): List of subnet ids. | ||
* 'SecurityGroupIds' (list[str]): List of security group ids. | ||
endpoint_type: Optional. Selects the software stack used by the inference server. | ||
If not specified, the model will be configured to use the default SageMaker model server. | ||
If 'tfs', the model will be configured to use the SageMaker Tensorflow Serving container. | ||
|
||
jesterhazy marked this conversation as resolved.
Show resolved
Hide resolved
|
||
Returns: | ||
sagemaker.tensorflow.model.TensorFlowModel: A SageMaker ``TensorFlowModel`` object. | ||
See :func:`~sagemaker.tensorflow.model.TensorFlowModel` for full details. | ||
""" | ||
env = {'SAGEMAKER_REQUIREMENTS': self.requirements_file} | ||
|
||
role = role or self.role | ||
return TensorFlowModel(self.model_data, role, self.entry_point, source_dir=self._model_source_dir(), | ||
enable_cloudwatch_metrics=self.enable_cloudwatch_metrics, env=env, image=self.image_name, | ||
name=self._current_job_name, container_log_level=self.container_log_level, | ||
if endpoint_type == 'tfs': | ||
return self._create_tfs_model(role=role, vpc_config_override=vpc_config_override) | ||
|
||
return self._create_default_model(model_server_workers=model_server_workers, role=role, | ||
vpc_config_override=vpc_config_override) | ||
|
||
def _create_tfs_model(self, role=None, vpc_config_override=VPC_CONFIG_DEFAULT): | ||
return TFSModel(model_data=self.model_data, | ||
role=role, | ||
image=self.image_name, | ||
name=self._current_job_name, | ||
framework_version=self.framework_version, | ||
sagemaker_session=self.sagemaker_session, | ||
vpc_config=self.get_vpc_config(vpc_config_override)) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Are you missing container_log_level? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yes |
||
|
||
def _create_default_model(self, model_server_workers, role, vpc_config_override): | ||
return TensorFlowModel(self.model_data, role, self.entry_point, | ||
source_dir=self._model_source_dir(), | ||
enable_cloudwatch_metrics=self.enable_cloudwatch_metrics, | ||
env={'SAGEMAKER_REQUIREMENTS': self.requirements_file}, | ||
image=self.image_name, | ||
name=self._current_job_name, | ||
container_log_level=self.container_log_level, | ||
code_location=self.code_location, py_version=self.py_version, | ||
framework_version=self.framework_version, model_server_workers=model_server_workers, | ||
framework_version=self.framework_version, | ||
model_server_workers=model_server_workers, | ||
sagemaker_session=self.sagemaker_session, | ||
vpc_config=self.get_vpc_config(vpc_config_override)) | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,150 @@ | ||
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"). You | ||
# may not use this file except in compliance with the License. A copy of | ||
# the License is located at | ||
# | ||
# http://aws.amazon.com/apache2.0/ | ||
# | ||
# or in the "license" file accompanying this file. This file is | ||
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF | ||
# ANY KIND, either express or implied. See the License for the specific | ||
# language governing permissions and limitations under the License. | ||
from __future__ import absolute_import | ||
|
||
import logging | ||
|
||
import sagemaker | ||
from sagemaker import Model, RealTimePredictor | ||
from sagemaker.content_types import CONTENT_TYPE_JSON | ||
from sagemaker.fw_utils import create_image_uri | ||
from sagemaker.predictor import json_serializer, json_deserializer | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: from google python style guide -> import modules not functions and classes. |
||
from sagemaker.tensorflow.defaults import TF_VERSION | ||
|
||
|
||
class TFSPredictor(RealTimePredictor):
    """A ``RealTimePredictor`` for endpoints backed by the TFS container.

    Requests and responses are JSON by default. Can optionally target a
    specific model (and model version) hosted on the endpoint by sending
    routing hints in the request's custom attributes.
    """

    def __init__(self, endpoint_name, sagemaker_session=None,
                 serializer=json_serializer,
                 deserializer=json_deserializer,
                 model_name=None,
                 model_version=None):
        """Create a ``TFSPredictor`` for an existing endpoint.

        See ``sagemaker.RealTimePredictor`` for more detail about the common
        parameters.

        Args:
            endpoint_name (str): Name of the endpoint to invoke.
            sagemaker_session (sagemaker.session.Session): Session used for
                SageMaker API calls. A default session is created when omitted.
            serializer (callable): Optional. Serializes request data; the
                default serializes dicts, lists, and numpy arrays to json.
            deserializer (callable): Optional. Parses responses; the default
                uses ``json.load(...)``.
            model_name (str): Optional. TFS model that should serve requests.
                When omitted, the endpoint's default model handles requests.
            model_version (str): Optional. Version of the TFS model to use.
                When omitted, the latest version of the model is used.
        """
        super(TFSPredictor, self).__init__(endpoint_name, sagemaker_session, serializer,
                                           deserializer)

        # Routing hints passed to the container via the CustomAttributes header.
        parts = []
        if model_name:
            parts.append('tfs-model-name={}'.format(model_name))
        if model_version:
            parts.append('tfs-model-version={}'.format(model_version))
        self._model_attributes = ','.join(parts) if parts else None

    def classify(self, data):
        """Invoke the TFS 'classify' API with ``data``."""
        return self._classify_or_regress(data, 'classify')

    def regress(self, data):
        """Invoke the TFS 'regress' API with ``data``."""
        return self._classify_or_regress(data, 'regress')

    def _classify_or_regress(self, data, method):
        # The classify/regress APIs are JSON-only; reject anything else early.
        if method not in ['classify', 'regress']:
            raise ValueError('invalid TensorFlow Serving method: {}'.format(method))

        if self.content_type != CONTENT_TYPE_JSON:
            raise ValueError('The {} api requires json requests.'.format(method))

        return self.predict(data, {'CustomAttributes': 'tfs-method={}'.format(method)})

    def predict(self, data, initial_args=None):
        """Invoke the endpoint, merging model routing attributes into the call."""
        args = {} if not initial_args else dict(initial_args)
        if self._model_attributes:
            if 'CustomAttributes' in args:
                args['CustomAttributes'] = args['CustomAttributes'] + ',' + self._model_attributes
            else:
                args['CustomAttributes'] = self._model_attributes

        return super(TFSPredictor, self).predict(data, args)
|
||
|
||
class TFSModel(Model):
    """A SageMaker ``Model`` that deploys to the TensorFlow Serving container.

    By default, ``deploy()`` returns a ``TFSPredictor`` for the created
    endpoint.
    """

    FRAMEWORK_NAME = 'tfs'

    # Container env var controlling the nginx log level inside the TFS container.
    LOG_LEVEL_PARAM_NAME = 'SAGEMAKER_TFS_NGINX_LOGLEVEL'
    # Maps Python logging levels to the nginx log level names the container expects.
    LOG_LEVEL_MAP = {
        logging.DEBUG: 'debug',
        logging.INFO: 'info',
        logging.WARNING: 'warn',
        logging.ERROR: 'error',
        logging.CRITICAL: 'crit',
    }

    def __init__(self, model_data, role, image=None, framework_version=TF_VERSION,
                 container_log_level=None, predictor_cls=TFSPredictor, **kwargs):
        """Initialize a TFSModel.

        Args:
            model_data (str): The S3 location of a SageMaker model data ``.tar.gz`` file.
            role (str): An AWS IAM role (either name or full ARN). The Amazon SageMaker APIs
                that create Amazon SageMaker endpoints use this role to access model artifacts.
            image (str): A Docker image URI (default: None). If not specified, a default image
                for TensorFlow Serving will be used.
            framework_version (str): Optional. TensorFlow Serving version you want to use.
            container_log_level (int): Optional. Log level to use within the container.
                Valid values are defined in the Python logging module. If not specified, the
                container's own default log level applies — presumably error-level;
                confirm against the container docs.
            predictor_cls (callable[str, sagemaker.session.Session]): A function to call to
                create a predictor with an endpoint name and SageMaker ``Session``. If
                specified, ``deploy()`` returns the result of invoking this function on the
                created endpoint name.
            **kwargs: Keyword arguments passed to the ``Model`` initializer.
        """
        super(TFSModel, self).__init__(model_data=model_data, role=role, image=image,
                                       predictor_cls=predictor_cls, **kwargs)
        self._framework_version = framework_version
        self._container_log_level = container_log_level

    def prepare_container_def(self, instance_type):
        """Return the container definition (image, model data, env) for deployment."""
        image = self._get_image_uri(instance_type)
        env = self._get_container_env()
        return sagemaker.container_def(image, self.model_data, env)

    def _get_container_env(self):
        """Return ``self.env``, extended with the nginx log level when one was requested."""
        if not self._container_log_level:
            return self.env

        if self._container_log_level not in TFSModel.LOG_LEVEL_MAP:
            logging.warning('ignoring invalid container log level: %s', self._container_log_level)
            return self.env

        # Copy so the model's shared env dict is not mutated.
        env = dict(self.env)
        # Use the declared constant for the env var key (was hardcoded inline).
        env[TFSModel.LOG_LEVEL_PARAM_NAME] = TFSModel.LOG_LEVEL_MAP[self._container_log_level]
        return env

    def _get_image_uri(self, instance_type):
        """Return the explicit image if set, else the default TFS image for the region."""
        if self.image:
            return self.image

        # Reuse the standard image uri function, then strip the unwanted python
        # component (the TFS container is not published with a -py3 suffix).
        region_name = self.sagemaker_session.boto_region_name
        image = create_image_uri(region_name, TFSModel.FRAMEWORK_NAME, instance_type,
                                 self._framework_version, 'py3')
        return image.replace('-py3', '')
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please clarify this docstring argument further.