Skip to content

Commit 8b02fbb

Browse files
committed
change: refactor endpoint support for TF transformer
1 parent a468d36 commit 8b02fbb

File tree

2 files changed

+36
-36
lines changed

2 files changed

+36
-36
lines changed

src/sagemaker/tensorflow/estimator.py

Lines changed: 21 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -644,21 +644,27 @@ def transformer(
644644
use the SageMaker Tensorflow Serving container.
645645
"""
646646

647-
if endpoint_type == "tensorflow-serving":
648-
self.script_mode = True
649-
return super(TensorFlow, self).transformer(
647+
role = role or self.role
648+
model = self.create_model(
649+
model_server_workers=model_server_workers,
650+
role=role,
651+
vpc_config_override=VPC_CONFIG_DEFAULT,
652+
endpoint_type=endpoint_type,
653+
)
654+
return model.transformer(
650655
instance_count,
651656
instance_type,
652-
strategy,
653-
assemble_with,
654-
output_path,
655-
output_kms_key,
656-
accept,
657-
env,
658-
max_concurrent_transforms,
659-
max_payload,
660-
tags,
661-
role,
662-
model_server_workers,
663-
volume_kms_key,
657+
strategy=strategy,
658+
assemble_with=assemble_with,
659+
output_path=output_path,
660+
output_kms_key=output_kms_key,
661+
accept=accept,
662+
env=env,
663+
max_concurrent_transforms=max_concurrent_transforms,
664+
max_payload=max_payload,
665+
tags=None,
666+
volume_kms_key=volume_kms_key,
667+
664668
)
669+
670+

tests/unit/test_tf_estimator.py

Lines changed: 15 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -308,29 +308,26 @@ def test_create_model_with_optional_params(sagemaker_session):
308308
assert model.vpc_config == vpc_config
309309

310310

311-
@patch("sagemaker.tensorflow.estimator.TensorFlow._create_tfs_model")
312-
def test_transformer_creation_with_endpoint_type(create_tfs_model, sagemaker_session):
311+
@patch("sagemaker.tensorflow.estimator.TensorFlow.create_model")
312+
def test_transformer_creation_with_endpoint_type(create_model, sagemaker_session):
313313
tf = TensorFlow(
314314
entry_point=SCRIPT_PATH,
315315
role=ROLE,
316316
sagemaker_session=sagemaker_session,
317317
train_instance_count=INSTANCE_COUNT,
318318
train_instance_type=INSTANCE_TYPE,
319319
)
320-
321320
tf.latest_training_job = _TrainingJob(sagemaker_session, JOB_NAME)
322-
transformer = tf.transformer(INSTANCE_COUNT, INSTANCE_TYPE, endpoint_type="tensorflow-serving")
323-
assert isinstance(transformer, Transformer)
324-
assert transformer.sagemaker_session == sagemaker_session
325-
assert transformer.instance_count == INSTANCE_COUNT
326-
assert transformer.instance_type == INSTANCE_TYPE
327-
assert tf.script_mode is True
328-
assert tf._script_mode_enabled() is True
329-
create_tfs_model.assert_called_once()
321+
322+
tf.transformer(INSTANCE_COUNT, INSTANCE_TYPE, model_server_workers=2, endpoint_type="tensorflow-serving")
323+
create_model.assert_called_with(endpoint_type='tensorflow-serving',
324+
model_server_workers=2,
325+
role='Dummy',
326+
vpc_config_override='VPC_CONFIG_DEFAULT')
330327

331328

332-
@patch("sagemaker.tensorflow.estimator.TensorFlow._create_default_model")
333-
def test_transformer_creation_without_endpoint_type(create_default_model, sagemaker_session):
329+
@patch("sagemaker.tensorflow.estimator.TensorFlow.create_model")
330+
def test_transformer_creation_without_endpoint_type(create_model, sagemaker_session):
334331

335332
tf = TensorFlow(
336333
entry_point=SCRIPT_PATH,
@@ -341,14 +338,11 @@ def test_transformer_creation_without_endpoint_type(create_default_model, sagema
341338
)
342339

343340
tf.latest_training_job = _TrainingJob(sagemaker_session, JOB_NAME)
344-
transformer = tf.transformer(INSTANCE_COUNT, INSTANCE_TYPE)
345-
assert isinstance(transformer, Transformer)
346-
assert transformer.sagemaker_session == sagemaker_session
347-
assert transformer.instance_count == INSTANCE_COUNT
348-
assert transformer.instance_type == INSTANCE_TYPE
349-
assert tf.script_mode is False
350-
assert tf._script_mode_enabled() is False
351-
create_default_model.assert_called_once()
341+
transformer = tf.transformer(INSTANCE_COUNT, INSTANCE_TYPE, model_server_workers=4)
342+
create_model.assert_called_with(endpoint_type=None,
343+
model_server_workers=4,
344+
role='Dummy',
345+
vpc_config_override='VPC_CONFIG_DEFAULT')
352346

353347

354348
def test_create_model_with_custom_image(sagemaker_session):

0 commit comments

Comments
 (0)