Skip to content

Commit 9d8a29a

Browse files
caxiaohumvsusp
authored and committed
feature: support Endpoint_type for TF transform (#881)
1 parent 4b8623e commit 9d8a29a

File tree

2 files changed

+146
-0
lines changed

2 files changed

+146
-0
lines changed

src/sagemaker/tensorflow/estimator.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -593,3 +593,75 @@ def train_image(self):
593593
)
594594

595595
return super(TensorFlow, self).train_image()
596+
597+
def transformer(
    self,
    instance_count,
    instance_type,
    strategy=None,
    assemble_with=None,
    output_path=None,
    output_kms_key=None,
    accept=None,
    env=None,
    max_concurrent_transforms=None,
    max_payload=None,
    tags=None,
    role=None,
    model_server_workers=None,
    volume_kms_key=None,
    endpoint_type=None,
):
    """Build a ``Transformer`` backed by a SageMaker Model created from this estimator.

    A model is first created through ``self.create_model`` (forwarding ``role``,
    ``model_server_workers`` and ``endpoint_type``), and the remaining options are
    handed to that model's ``transformer`` method. The SageMaker Session and base
    job name used by the Estimator are reused.

    Args:
        instance_count (int): Number of EC2 instances to use.
        instance_type (str): Type of EC2 instance to use, for example, 'ml.c4.xlarge'.
        strategy (str): How records are batched into a single request (default: None).
            Valid values: 'MULTI_RECORD' and 'SINGLE_RECORD'.
        assemble_with (str): How the output is assembled (default: None).
            Valid values: 'Line' or 'None'.
        output_path (str): S3 location for saving the transform result. If not
            specified, results are stored to a default bucket.
        output_kms_key (str): Optional. KMS key ID for encrypting the transform
            output (default: None).
        accept (str): The content type accepted by the endpoint deployed during
            the transform job.
        env (dict): Environment variables to be set for use during the transform
            job (default: None).
        max_concurrent_transforms (int): The maximum number of HTTP requests to be
            made to each individual transform container at one time.
        max_payload (int): Maximum size of the payload in a single HTTP request to
            the container in MB.
        tags (list[dict]): List of tags for labeling a transform job. If none
            specified, then the tags used for the training job are used for the
            transform job.
        role (str): The ``ExecutionRoleArn`` IAM Role ARN for the ``Model``, which
            is also used during transform jobs. If not specified, the role from
            the Estimator will be used.
        model_server_workers (int): Optional. The number of worker processes used
            by the inference server. If None, server will use one worker per vCPU.
        volume_kms_key (str): Optional. KMS key ID for encrypting the volume
            attached to the ML compute instance (default: None).
        endpoint_type (str): Optional. Selects the software stack used by the
            inference server. If not specified, the model will be configured to
            use the default SageMaker model server. If 'tensorflow-serving', the
            model will be configured to use the SageMaker Tensorflow Serving
            container.

    Returns:
        Whatever the created model's ``transformer`` method returns.
    """
    # Any falsy ``role`` (None, empty string) falls back to the estimator's own role.
    if not role:
        role = self.role

    serving_model = self.create_model(
        role=role,
        endpoint_type=endpoint_type,
        model_server_workers=model_server_workers,
        vpc_config_override=VPC_CONFIG_DEFAULT,
    )

    # Everything except the model-creation options is passed straight through.
    transform_options = {
        "strategy": strategy,
        "assemble_with": assemble_with,
        "output_path": output_path,
        "output_kms_key": output_kms_key,
        "accept": accept,
        "env": env,
        "max_concurrent_transforms": max_concurrent_transforms,
        "max_payload": max_payload,
        "tags": tags,
        "volume_kms_key": volume_kms_key,
    }
    return serving_model.transformer(instance_count, instance_type, **transform_options)

tests/unit/test_tf_estimator.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from sagemaker.tensorflow import defaults, TensorFlow, TensorFlowModel, TensorFlowPredictor
2626
import sagemaker.tensorflow.estimator as tfe
2727

28+
2829
DATA_DIR = os.path.join(os.path.dirname(__file__), "..", "data")
2930
SCRIPT_FILE = "dummy_script.py"
3031
SCRIPT_PATH = os.path.join(DATA_DIR, SCRIPT_FILE)
@@ -305,6 +306,79 @@ def test_create_model_with_optional_params(sagemaker_session):
305306
assert model.vpc_config == vpc_config
306307

307308

309+
@patch("sagemaker.tensorflow.estimator.TensorFlow.create_model")
def test_transformer_creation_with_endpoint_type(create_model, sagemaker_session):
    """``endpoint_type`` given to ``transformer`` must reach ``create_model`` unchanged."""
    fake_model = Mock()
    create_model.return_value = fake_model

    estimator = TensorFlow(
        entry_point=SCRIPT_PATH,
        role=ROLE,
        sagemaker_session=sagemaker_session,
        train_instance_count=INSTANCE_COUNT,
        train_instance_type=INSTANCE_TYPE,
    )
    estimator.transformer(INSTANCE_COUNT, INSTANCE_TYPE, endpoint_type="tensorflow-serving")

    # Model creation receives the serving stack selection plus the defaults.
    expected_model_kwargs = {
        "endpoint_type": "tensorflow-serving",
        "model_server_workers": None,
        "role": ROLE,
        "vpc_config_override": "VPC_CONFIG_DEFAULT",
    }
    create_model.assert_called_with(**expected_model_kwargs)

    # No transform option was supplied, so each one is forwarded as None.
    expected_transform_kwargs = dict.fromkeys(
        (
            "accept",
            "assemble_with",
            "env",
            "max_concurrent_transforms",
            "max_payload",
            "output_kms_key",
            "output_path",
            "strategy",
            "tags",
            "volume_kms_key",
        )
    )
    fake_model.transformer.assert_called_with(
        INSTANCE_COUNT, INSTANCE_TYPE, **expected_transform_kwargs
    )
344+
345+
346+
@patch("sagemaker.tensorflow.estimator.TensorFlow.create_model")
def test_transformer_creation_without_endpoint_type(create_model, sagemaker_session):
    """Omitting ``endpoint_type`` must pass ``endpoint_type=None`` to ``create_model``."""
    fake_model = Mock()
    create_model.return_value = fake_model

    estimator = TensorFlow(
        entry_point=SCRIPT_PATH,
        role=ROLE,
        sagemaker_session=sagemaker_session,
        train_instance_count=INSTANCE_COUNT,
        train_instance_type=INSTANCE_TYPE,
    )
    estimator.transformer(INSTANCE_COUNT, INSTANCE_TYPE)

    # Model creation sees only defaults, including an unset endpoint type.
    expected_model_kwargs = {
        "endpoint_type": None,
        "model_server_workers": None,
        "role": ROLE,
        "vpc_config_override": "VPC_CONFIG_DEFAULT",
    }
    create_model.assert_called_with(**expected_model_kwargs)

    # No transform option was supplied, so each one is forwarded as None.
    expected_transform_kwargs = dict.fromkeys(
        (
            "accept",
            "assemble_with",
            "env",
            "max_concurrent_transforms",
            "max_payload",
            "output_kms_key",
            "output_path",
            "strategy",
            "tags",
            "volume_kms_key",
        )
    )
    fake_model.transformer.assert_called_with(
        INSTANCE_COUNT, INSTANCE_TYPE, **expected_transform_kwargs
    )
380+
381+
308382
def test_create_model_with_custom_image(sagemaker_session):
309383
container_log_level = '"logging.INFO"'
310384
source_dir = "s3://mybucket/source"

0 commit comments

Comments
 (0)