Skip to content

Commit 52292c5

Browse files
committed
change: minor changes based on feedback
1 parent e687974 commit 52292c5

File tree

2 files changed

+20
-8
lines changed

2 files changed

+20
-8
lines changed

src/sagemaker/tensorflow/estimator.py

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -531,15 +531,26 @@ def transformer(self, instance_count, instance_type, strategy=None, assemble_wit
531531
If None, server will use one worker per vCPU.
532532
volume_kms_key (str): Optional. KMS key ID for encrypting the volume attached to the ML
533533
compute instance (default: None).
534-
endpoint_type: Optional. Selects the software stack used by the inference server.
535-
If not specified, the model will be configured to use the default
536-
SageMaker model server. If 'tensorflow-serving', the model will be configured to
534+
endpoint_type (str): Optional. Selects the software stack used by the inference server.
535+
If not specified, the model will be configured to use the default
536+
SageMaker model server.
537+
If 'tensorflow-serving', the model will be configured to
537538
use the SageMaker Tensorflow Serving container.
538539
"""
539540

540541
if endpoint_type == 'tensorflow-serving':
541542
self.script_mode = True
542-
return super(TensorFlow, self).transformer(instance_count, instance_type, strategy, assemble_with, output_path,
543-
output_kms_key, accept, env, max_concurrent_transforms, max_payload,
544-
tags, role, model_server_workers, volume_kms_key
545-
)
543+
return super(TensorFlow, self).transformer(instance_count,
544+
instance_type,
545+
strategy,
546+
assemble_with,
547+
output_path,
548+
output_kms_key,
549+
accept,
550+
env,
551+
max_concurrent_transforms,
552+
max_payload,
553+
tags,
554+
role,
555+
model_server_workers,
556+
volume_kms_key)

tests/unit/test_tf_estimator.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,11 @@
1919
import pytest
2020
from mock import patch, Mock, MagicMock
2121

22+
from sagemaker.estimator import _TrainingJob
2223
from sagemaker.fw_utils import create_image_uri
2324
from sagemaker.model import MODEL_SERVER_WORKERS_PARAM_NAME
2425
from sagemaker.session import s3_input
2526
from sagemaker.tensorflow import defaults, TensorFlow, TensorFlowModel, TensorFlowPredictor
26-
from sagemaker.estimator import _TrainingJob
2727
import sagemaker.tensorflow.estimator as tfe
2828
from sagemaker.transformer import Transformer
2929

@@ -293,6 +293,7 @@ def test_transformer_creation_with_endpoint_type(create_tfs_model, sagemaker_ses
293293
assert tf.script_mode is True
294294
assert tf._script_mode_enabled() is True
295295

296+
296297
@patch('sagemaker.tensorflow.estimator.TensorFlow._create_default_model')
297298
def test_transformer_creation_without_endpoint_type(create_default_model, sagemaker_session):
298299
container_log_level = '"logging.INFO"'

0 commit comments

Comments
 (0)