Skip to content

feature: support Endpoint_type for TF transform #881

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 21 commits into from
Jul 1, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 72 additions & 0 deletions src/sagemaker/tensorflow/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -593,3 +593,75 @@ def train_image(self):
)

return super(TensorFlow, self).train_image()

def transformer(
self,
instance_count,
instance_type,
strategy=None,
assemble_with=None,
output_path=None,
output_kms_key=None,
accept=None,
env=None,
max_concurrent_transforms=None,
max_payload=None,
tags=None,
role=None,
model_server_workers=None,
volume_kms_key=None,
endpoint_type=None,
):
"""Return a ``Transformer`` that uses a SageMaker Model based on the training job. It reuses the
SageMaker Session and base job name used by the Estimator.

Args:
instance_count (int): Number of EC2 instances to use.
instance_type (str): Type of EC2 instance to use, for example, 'ml.c4.xlarge'.
strategy (str): The strategy used to decide how to batch records in a single request (default: None).
Valid values: 'MULTI_RECORD' and 'SINGLE_RECORD'.
assemble_with (str): How the output is assembled (default: None). Valid values: 'Line' or 'None'.
output_path (str): S3 location for saving the transform result. If not specified, results are stored to
a default bucket.
output_kms_key (str): Optional. KMS key ID for encrypting the transform output (default: None).
accept (str): The content type accepted by the endpoint deployed during the transform job.
env (dict): Environment variables to be set for use during the transform job (default: None).
max_concurrent_transforms (int): The maximum number of HTTP requests to be made to
each individual transform container at one time.
max_payload (int): Maximum size of the payload in a single HTTP request to the container in MB.
tags (list[dict]): List of tags for labeling a transform job. If none specified, then the tags used for
the training job are used for the transform job.
role (str): The ``ExecutionRoleArn`` IAM Role ARN for the ``Model``, which is also used during
transform jobs. If not specified, the role from the Estimator will be used.
model_server_workers (int): Optional. The number of worker processes used by the inference server.
If None, server will use one worker per vCPU.
volume_kms_key (str): Optional. KMS key ID for encrypting the volume attached to the ML
compute instance (default: None).
endpoint_type (str): Optional. Selects the software stack used by the inference server.
If not specified, the model will be configured to use the default
SageMaker model server.
If 'tensorflow-serving', the model will be configured to
use the SageMaker Tensorflow Serving container.
"""

role = role or self.role
model = self.create_model(
model_server_workers=model_server_workers,
role=role,
vpc_config_override=VPC_CONFIG_DEFAULT,
endpoint_type=endpoint_type,
)
return model.transformer(
instance_count,
instance_type,
strategy=strategy,
assemble_with=assemble_with,
output_path=output_path,
output_kms_key=output_kms_key,
accept=accept,
env=env,
max_concurrent_transforms=max_concurrent_transforms,
max_payload=max_payload,
tags=tags,
volume_kms_key=volume_kms_key,
)
74 changes: 74 additions & 0 deletions tests/unit/test_tf_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from sagemaker.tensorflow import defaults, TensorFlow, TensorFlowModel, TensorFlowPredictor
import sagemaker.tensorflow.estimator as tfe


DATA_DIR = os.path.join(os.path.dirname(__file__), "..", "data")
SCRIPT_FILE = "dummy_script.py"
SCRIPT_PATH = os.path.join(DATA_DIR, SCRIPT_FILE)
Expand Down Expand Up @@ -305,6 +306,79 @@ def test_create_model_with_optional_params(sagemaker_session):
assert model.vpc_config == vpc_config


@patch("sagemaker.tensorflow.estimator.TensorFlow.create_model")
def test_transformer_creation_with_endpoint_type(create_model, sagemaker_session):
    # The model returned by create_model is mocked so we can inspect how the
    # estimator delegates to Model.transformer.
    mock_model = Mock()
    create_model.return_value = mock_model

    estimator = TensorFlow(
        entry_point=SCRIPT_PATH,
        role=ROLE,
        sagemaker_session=sagemaker_session,
        train_instance_count=INSTANCE_COUNT,
        train_instance_type=INSTANCE_TYPE,
    )

    estimator.transformer(INSTANCE_COUNT, INSTANCE_TYPE, endpoint_type="tensorflow-serving")

    # The requested endpoint type must be forwarded to create_model.
    create_model.assert_called_with(
        endpoint_type="tensorflow-serving",
        model_server_workers=None,
        role=ROLE,
        vpc_config_override="VPC_CONFIG_DEFAULT",
    )

    # All optional transform arguments default to None.
    expected_transform_kwargs = {
        "accept": None,
        "assemble_with": None,
        "env": None,
        "max_concurrent_transforms": None,
        "max_payload": None,
        "output_kms_key": None,
        "output_path": None,
        "strategy": None,
        "tags": None,
        "volume_kms_key": None,
    }
    mock_model.transformer.assert_called_with(
        INSTANCE_COUNT, INSTANCE_TYPE, **expected_transform_kwargs
    )


@patch("sagemaker.tensorflow.estimator.TensorFlow.create_model")
def test_transformer_creation_without_endpoint_type(create_model, sagemaker_session):
    # Mock out the created model so the delegation to Model.transformer can be
    # asserted without touching any AWS APIs.
    mock_model = Mock()
    create_model.return_value = mock_model

    estimator = TensorFlow(
        entry_point=SCRIPT_PATH,
        role=ROLE,
        sagemaker_session=sagemaker_session,
        train_instance_count=INSTANCE_COUNT,
        train_instance_type=INSTANCE_TYPE,
    )

    estimator.transformer(INSTANCE_COUNT, INSTANCE_TYPE)

    # When no endpoint type is given, None is passed through to create_model,
    # selecting the default SageMaker model server.
    create_model.assert_called_with(
        endpoint_type=None,
        model_server_workers=None,
        role=ROLE,
        vpc_config_override="VPC_CONFIG_DEFAULT",
    )

    expected_transform_kwargs = {
        "accept": None,
        "assemble_with": None,
        "env": None,
        "max_concurrent_transforms": None,
        "max_payload": None,
        "output_kms_key": None,
        "output_path": None,
        "strategy": None,
        "tags": None,
        "volume_kms_key": None,
    }
    mock_model.transformer.assert_called_with(
        INSTANCE_COUNT, INSTANCE_TYPE, **expected_transform_kwargs
    )


def test_create_model_with_custom_image(sagemaker_session):
container_log_level = '"logging.INFO"'
source_dir = "s3://mybucket/source"
Expand Down