Skip to content

feature: support endpoint_type for TF transform #881

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 21 commits into from
Jul 1, 2019
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 53 additions & 0 deletions src/sagemaker/tensorflow/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -501,3 +501,56 @@ def train_image(self):
self.train_instance_type, self.framework_version, self.py_version)

return super(TensorFlow, self).train_image()

def transformer(self, instance_count, instance_type, strategy=None, assemble_with=None, output_path=None,
                output_kms_key=None, accept=None, env=None, max_concurrent_transforms=None,
                max_payload=None, tags=None, role=None, model_server_workers=None, volume_kms_key=None,
                endpoint_type=None):
    """Return a ``Transformer`` that uses a SageMaker Model based on the training job.

    Reuses the SageMaker Session and base job name used by the Estimator.

    Args:
        instance_count (int): Number of EC2 instances to use.
        instance_type (str): Type of EC2 instance to use, for example, 'ml.c4.xlarge'.
        strategy (str): The strategy used to decide how to batch records in a single request (default: None).
            Valid values: 'MULTI_RECORD' and 'SINGLE_RECORD'.
        assemble_with (str): How the output is assembled (default: None). Valid values: 'Line' or 'None'.
        output_path (str): S3 location for saving the transform result. If not specified, results are stored to
            a default bucket.
        output_kms_key (str): Optional. KMS key ID for encrypting the transform output (default: None).
        accept (str): The content type accepted by the endpoint deployed during the transform job.
        env (dict): Environment variables to be set for use during the transform job (default: None).
        max_concurrent_transforms (int): The maximum number of HTTP requests to be made to
            each individual transform container at one time.
        max_payload (int): Maximum size of the payload in a single HTTP request to the container in MB.
        tags (list[dict]): List of tags for labeling a transform job. If none specified, then the tags used for
            the training job are used for the transform job.
        role (str): The ``ExecutionRoleArn`` IAM Role ARN for the ``Model``, which is also used during
            transform jobs. If not specified, the role from the Estimator will be used.
        model_server_workers (int): Optional. The number of worker processes used by the inference server.
            If None, server will use one worker per vCPU.
        volume_kms_key (str): Optional. KMS key ID for encrypting the volume attached to the ML
            compute instance (default: None).
        endpoint_type (str): Optional. Selects the software stack used by the inference server.
            If not specified, the model will be configured to use the default
            SageMaker model server.
            If 'tensorflow-serving', the model will be configured to
            use the SageMaker Tensorflow Serving container.

    Returns:
        sagemaker.transformer.Transformer: a ``Transformer`` built from this estimator's
        latest training job (as returned by the parent class implementation).
    """
    if endpoint_type == 'tensorflow-serving':
        # The TF Serving container consumes script-mode artifacts, so switch the
        # estimator into script mode before the parent class builds the model.
        # NOTE: this mutates estimator state; any later model created from this
        # estimator will also be in script mode.
        self.script_mode = True

    # Forward everything by keyword: with this many parameters, a positional
    # pass-through would silently misroute arguments if the parent signature
    # ever gains or reorders a parameter.
    return super(TensorFlow, self).transformer(
        instance_count=instance_count,
        instance_type=instance_type,
        strategy=strategy,
        assemble_with=assemble_with,
        output_path=output_path,
        output_kms_key=output_kms_key,
        accept=accept,
        env=env,
        max_concurrent_transforms=max_concurrent_transforms,
        max_payload=max_payload,
        tags=tags,
        role=role,
        model_server_workers=model_server_workers,
        volume_kms_key=volume_kms_key)
49 changes: 48 additions & 1 deletion tests/unit/test_tf_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,13 @@
import pytest
from mock import patch, Mock, MagicMock

from sagemaker.estimator import _TrainingJob
from sagemaker.fw_utils import create_image_uri
from sagemaker.model import MODEL_SERVER_WORKERS_PARAM_NAME
from sagemaker.session import s3_input
from sagemaker.tensorflow import defaults, TensorFlow, TensorFlowModel, TensorFlowPredictor
import sagemaker.tensorflow.estimator as tfe
from sagemaker.transformer import Transformer

DATA_DIR = os.path.join(os.path.dirname(__file__), '..', 'data')
SCRIPT_FILE = 'dummy_script.py'
Expand Down Expand Up @@ -264,12 +266,57 @@ def test_create_model_with_optional_params(sagemaker_session):
vpc_config = {'Subnets': ['foo'], 'SecurityGroupIds': ['bar']}
model = tf.create_model(role=new_role, model_server_workers=model_server_workers,
vpc_config_override=vpc_config)

assert model.role == new_role
assert model.model_server_workers == model_server_workers
assert model.vpc_config == vpc_config


@patch('sagemaker.tensorflow.estimator.TensorFlow._create_tfs_model')
def test_transformer_creation_with_endpoint_type(create_tfs_model, sagemaker_session):
    """endpoint_type='tensorflow-serving' must build a TFS model and flip the
    estimator into script mode."""
    tf = TensorFlow(entry_point=SCRIPT_PATH, role=ROLE, sagemaker_session=sagemaker_session,
                    training_steps=1000, evaluation_steps=10, train_instance_count=INSTANCE_COUNT,
                    train_instance_type=INSTANCE_TYPE)
    # transformer() derives the model from the most recent training job, so a
    # completed (fake) job must be attached before calling it.
    tf.latest_training_job = _TrainingJob(sagemaker_session, JOB_NAME)

    transformer = tf.transformer(INSTANCE_COUNT, INSTANCE_TYPE, endpoint_type='tensorflow-serving')

    create_tfs_model.assert_called_once()
    assert isinstance(transformer, Transformer)
    assert transformer.sagemaker_session == sagemaker_session
    assert transformer.instance_count == INSTANCE_COUNT
    assert transformer.instance_type == INSTANCE_TYPE
    assert transformer.tags is None
    assert tf.script_mode is True
    assert tf._script_mode_enabled() is True


@patch('sagemaker.tensorflow.estimator.TensorFlow._create_default_model')
def test_transformer_creation_without_endpoint_type(create_default_model, sagemaker_session):
    """Without endpoint_type, transformer() must use the default model server
    and leave script mode disabled."""
    tf = TensorFlow(entry_point=SCRIPT_PATH, role=ROLE, sagemaker_session=sagemaker_session,
                    training_steps=1000, evaluation_steps=10, train_instance_count=INSTANCE_COUNT,
                    train_instance_type=INSTANCE_TYPE)
    # transformer() derives the model from the most recent training job, so a
    # completed (fake) job must be attached before calling it.
    tf.latest_training_job = _TrainingJob(sagemaker_session, JOB_NAME)

    transformer = tf.transformer(INSTANCE_COUNT, INSTANCE_TYPE)

    create_default_model.assert_called_once()
    assert isinstance(transformer, Transformer)
    assert transformer.sagemaker_session == sagemaker_session
    assert transformer.instance_count == INSTANCE_COUNT
    assert transformer.instance_type == INSTANCE_TYPE
    assert transformer.tags is None
    assert tf.script_mode is False
    assert tf._script_mode_enabled() is False


def test_create_model_with_custom_image(sagemaker_session):
container_log_level = '"logging.INFO"'
source_dir = 's3://mybucket/source'
Expand Down