Commit e3a34ea

Add endpoint_type support for TF transform
Parent: 0bd044c

2 files changed (+89, -1)

src/sagemaker/tensorflow/estimator.py

Lines changed: 42 additions & 0 deletions
@@ -501,3 +501,45 @@ def train_image(self):
                                              self.train_instance_type, self.framework_version, self.py_version)
 
         return super(TensorFlow, self).train_image()
+
+    def transformer(self, instance_count, instance_type, strategy=None, assemble_with=None, output_path=None,
+                    output_kms_key=None, accept=None, env=None, max_concurrent_transforms=None,
+                    max_payload=None, tags=None, role=None, model_server_workers=None, volume_kms_key=None,
+                    endpoint_type=None):
+        """Return a ``Transformer`` that uses a SageMaker Model based on the training job. It reuses the
+        SageMaker Session and base job name used by the Estimator.
+
+        Args:
+            instance_count (int): Number of EC2 instances to use.
+            instance_type (str): Type of EC2 instance to use, for example, 'ml.c4.xlarge'.
+            strategy (str): The strategy used to decide how to batch records in a single request (default: None).
+                Valid values: 'MULTI_RECORD' and 'SINGLE_RECORD'.
+            assemble_with (str): How the output is assembled (default: None). Valid values: 'Line' or 'None'.
+            output_path (str): S3 location for saving the transform result. If not specified, results are stored to
+                a default bucket.
+            output_kms_key (str): Optional. KMS key ID for encrypting the transform output (default: None).
+            accept (str): The content type accepted by the endpoint deployed during the transform job.
+            env (dict): Environment variables to be set for use during the transform job (default: None).
+            max_concurrent_transforms (int): The maximum number of HTTP requests to be made to
+                each individual transform container at one time.
+            max_payload (int): Maximum size of the payload in a single HTTP request to the container in MB.
+            tags (list[dict]): List of tags for labeling a transform job. If none specified, then the tags used for
+                the training job are used for the transform job.
+            role (str): The ``ExecutionRoleArn`` IAM Role ARN for the ``Model``, which is also used during
+                transform jobs. If not specified, the role from the Estimator will be used.
+            model_server_workers (int): Optional. The number of worker processes used by the inference server.
+                If None, server will use one worker per vCPU.
+            volume_kms_key (str): Optional. KMS key ID for encrypting the volume attached to the ML
+                compute instance (default: None).
+            endpoint_type: Optional. Selects the software stack used by the inference server.
+                If not specified, the model will be configured to use the default
+                SageMaker model server. If 'tensorflow-serving', the model will be configured to
+                use the SageMaker Tensorflow Serving container.
+        """
+
+        if endpoint_type == 'tensorflow-serving':
+            self.script_mode = True
+        return super(TensorFlow, self).transformer(instance_count, instance_type, strategy, assemble_with, output_path,
+                                                   output_kms_key, accept, env, max_concurrent_transforms, max_payload,
+                                                   tags, role, model_server_workers, volume_kms_key
+                                                   )
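
For reference, a minimal usage sketch (not part of the commit) of how a caller might exercise the new endpoint_type argument. The entry point, S3 URIs, role ARN, framework_version, py_version, and instance types below are hypothetical placeholders; only the transformer(..., endpoint_type='tensorflow-serving') call reflects the change in this diff.

    # Sketch only: train a TensorFlow estimator, then request a batch transformer
    # backed by the TensorFlow Serving container via the new endpoint_type argument.
    from sagemaker.tensorflow import TensorFlow

    estimator = TensorFlow(entry_point='train.py',                         # hypothetical script
                           role='arn:aws:iam::111122223333:role/SageMakerRole',  # hypothetical role
                           train_instance_count=1,
                           train_instance_type='ml.p2.xlarge',
                           framework_version='1.11',                        # assumed version
                           py_version='py2')
    estimator.fit('s3://my-bucket/training-data')                           # hypothetical S3 input

    # Per the diff, endpoint_type='tensorflow-serving' flips script_mode on, so the
    # transform job is configured to use the SageMaker TensorFlow Serving container
    # instead of the default model server.
    transformer = estimator.transformer(instance_count=1,
                                        instance_type='ml.c4.xlarge',
                                        endpoint_type='tensorflow-serving')
    transformer.transform('s3://my-bucket/batch-input', content_type='application/json')
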

tests/unit/test_tf_estimator.py

Lines changed: 47 additions & 1 deletion
@@ -23,7 +23,9 @@
 from sagemaker.model import MODEL_SERVER_WORKERS_PARAM_NAME
 from sagemaker.session import s3_input
 from sagemaker.tensorflow import defaults, TensorFlow, TensorFlowModel, TensorFlowPredictor
+from sagemaker.estimator import _TrainingJob
 import sagemaker.tensorflow.estimator as tfe
+from sagemaker.transformer import Transformer
 
 DATA_DIR = os.path.join(os.path.dirname(__file__), '..', 'data')
 SCRIPT_FILE = 'dummy_script.py'
@@ -264,12 +266,56 @@ def test_create_model_with_optional_params(sagemaker_session):
     vpc_config = {'Subnets': ['foo'], 'SecurityGroupIds': ['bar']}
     model = tf.create_model(role=new_role, model_server_workers=model_server_workers,
                             vpc_config_override=vpc_config)
-
     assert model.role == new_role
     assert model.model_server_workers == model_server_workers
     assert model.vpc_config == vpc_config
 
 
+@patch('sagemaker.tensorflow.estimator.TensorFlow._create_tfs_model')
+def test_transformer_creation_with_endpoint_type(create_tfs_model, sagemaker_session):
+    container_log_level = '"logging.INFO"'
+    source_dir = 's3://mybucket/source'
+    enable_cloudwatch_metrics = 'true'
+    base_name = 'foo'
+    tf = TensorFlow(entry_point=SCRIPT_PATH, role=ROLE, sagemaker_session=sagemaker_session,
+                    training_steps=1000, evaluation_steps=10, train_instance_count=INSTANCE_COUNT,
+                    train_instance_type=INSTANCE_TYPE, container_log_level=container_log_level, base_job_name=base_name,
+                    source_dir=source_dir, enable_cloudwatch_metrics=enable_cloudwatch_metrics)
+    tf.latest_training_job = _TrainingJob(sagemaker_session, JOB_NAME)
+    assert isinstance(tf, TensorFlow)
+    transformer = tf.transformer(INSTANCE_COUNT, INSTANCE_TYPE, endpoint_type='tensorflow-serving')
+    create_tfs_model.assert_called_once()
+    assert isinstance(transformer, Transformer)
+    assert transformer.sagemaker_session == sagemaker_session
+    assert transformer.instance_count == INSTANCE_COUNT
+    assert transformer.instance_type == INSTANCE_TYPE
+    assert transformer.tags is None
+    assert tf.script_mode is True
+    assert tf._script_mode_enabled() is True
+
+@patch('sagemaker.tensorflow.estimator.TensorFlow._create_default_model')
+def test_transformer_creation_without_endpoint_type(create_default_model, sagemaker_session):
+    container_log_level = '"logging.INFO"'
+    source_dir = 's3://mybucket/source'
+    enable_cloudwatch_metrics = 'true'
+    base_name = 'flo'
+    tf = TensorFlow(entry_point=SCRIPT_PATH, role=ROLE, sagemaker_session=sagemaker_session,
+                    training_steps=1000, evaluation_steps=10, train_instance_count=INSTANCE_COUNT,
+                    train_instance_type=INSTANCE_TYPE, container_log_level=container_log_level, base_job_name=base_name,
+                    source_dir=source_dir, enable_cloudwatch_metrics=enable_cloudwatch_metrics)
+    tf.latest_training_job = _TrainingJob(sagemaker_session, JOB_NAME)
+    assert isinstance(tf, TensorFlow)
+    transformer = tf.transformer(INSTANCE_COUNT, INSTANCE_TYPE)
+    create_default_model.assert_called_once()
+    assert isinstance(transformer, Transformer)
+    assert transformer.sagemaker_session == sagemaker_session
+    assert transformer.instance_count == INSTANCE_COUNT
+    assert transformer.instance_type == INSTANCE_TYPE
+    assert transformer.tags is None
+    assert tf.script_mode is False
+    assert tf._script_mode_enabled() is False
+
+
 def test_create_model_with_custom_image(sagemaker_session):
     container_log_level = '"logging.INFO"'
     source_dir = 's3://mybucket/source'

0 commit comments