Skip to content

Commit 356c5d1

Browse files
committed
change styling
1 parent 0fa2a13 commit 356c5d1

File tree

2 files changed

+58
-33
lines changed

2 files changed

+58
-33
lines changed

src/sagemaker/tensorflow/estimator.py

Lines changed: 35 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -594,10 +594,24 @@ def train_image(self):
594594

595595
return super(TensorFlow, self).train_image()
596596

597-
def transformer(self, instance_count, instance_type, strategy=None, assemble_with=None, output_path=None,
598-
output_kms_key=None, accept=None, env=None, max_concurrent_transforms=None,
599-
max_payload=None, tags=None, role=None, model_server_workers=None, volume_kms_key=None,
600-
endpoint_type=None):
597+
def transformer(
598+
self,
599+
instance_count,
600+
instance_type,
601+
strategy=None,
602+
assemble_with=None,
603+
output_path=None,
604+
output_kms_key=None,
605+
accept=None,
606+
env=None,
607+
max_concurrent_transforms=None,
608+
max_payload=None,
609+
tags=None,
610+
role=None,
611+
model_server_workers=None,
612+
volume_kms_key=None,
613+
endpoint_type=None,
614+
):
601615
"""Return a ``Transformer`` that uses a SageMaker Model based on the training job. It reuses the
602616
SageMaker Session and base job name used by the Estimator.
603617
@@ -630,19 +644,21 @@ def transformer(self, instance_count, instance_type, strategy=None, assemble_wit
630644
use the SageMaker Tensorflow Serving container.
631645
"""
632646

633-
if endpoint_type == 'tensorflow-serving':
647+
if endpoint_type == "tensorflow-serving":
634648
self.script_mode = True
635-
return super(TensorFlow, self).transformer(instance_count,
636-
instance_type,
637-
strategy,
638-
assemble_with,
639-
output_path,
640-
output_kms_key,
641-
accept,
642-
env,
643-
max_concurrent_transforms,
644-
max_payload,
645-
tags,
646-
role,
647-
model_server_workers,
648-
volume_kms_key)
649+
return super(TensorFlow, self).transformer(
650+
instance_count,
651+
instance_type,
652+
strategy,
653+
assemble_with,
654+
output_path,
655+
output_kms_key,
656+
accept,
657+
env,
658+
max_concurrent_transforms,
659+
max_payload,
660+
tags,
661+
role,
662+
model_server_workers,
663+
volume_kms_key,
664+
)

tests/unit/test_tf_estimator.py

Lines changed: 23 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -298,48 +298,57 @@ def test_create_model_with_optional_params(sagemaker_session):
298298
new_role = "role"
299299
model_server_workers = 2
300300

301-
vpc_config = {'Subnets': ['foo'], 'SecurityGroupIds': ['bar']}
302-
model = tf.create_model(role=new_role, model_server_workers=model_server_workers,
303-
vpc_config_override=vpc_config)
301+
vpc_config = {"Subnets": ["foo"], "SecurityGroupIds": ["bar"]}
302+
model = tf.create_model(
303+
role=new_role, model_server_workers=model_server_workers, vpc_config_override=vpc_config
304+
)
304305

305306
assert model.role == new_role
306307
assert model.model_server_workers == model_server_workers
307308
assert model.vpc_config == vpc_config
308309

309310

310-
@patch('sagemaker.tensorflow.estimator.TensorFlow._create_tfs_model')
311+
@patch("sagemaker.tensorflow.estimator.TensorFlow._create_tfs_model")
311312
def test_transformer_creation_with_endpoint_type(create_tfs_model, sagemaker_session):
312-
tf = TensorFlow(entry_point=SCRIPT_PATH, role=ROLE, sagemaker_session=sagemaker_session,
313-
train_instance_count=INSTANCE_COUNT, train_instance_type=INSTANCE_TYPE)
313+
tf = TensorFlow(
314+
entry_point=SCRIPT_PATH,
315+
role=ROLE,
316+
sagemaker_session=sagemaker_session,
317+
train_instance_count=INSTANCE_COUNT,
318+
train_instance_type=INSTANCE_TYPE,
319+
)
314320

315321
tf.latest_training_job = _TrainingJob(sagemaker_session, JOB_NAME)
316-
transformer = tf.transformer(INSTANCE_COUNT, INSTANCE_TYPE, endpoint_type='tensorflow-serving')
322+
transformer = tf.transformer(INSTANCE_COUNT, INSTANCE_TYPE, endpoint_type="tensorflow-serving")
317323
assert isinstance(transformer, Transformer)
318-
create_tfs_model.assert_called_once()
319324
assert transformer.sagemaker_session == sagemaker_session
320325
assert transformer.instance_count == INSTANCE_COUNT
321326
assert transformer.instance_type == INSTANCE_TYPE
322-
assert transformer.tags is None
323327
assert tf.script_mode is True
324328
assert tf._script_mode_enabled() is True
329+
create_tfs_model.assert_called_once()
325330

326331

327-
@patch('sagemaker.tensorflow.estimator.TensorFlow._create_default_model')
332+
@patch("sagemaker.tensorflow.estimator.TensorFlow._create_default_model")
328333
def test_transformer_creation_without_endpoint_type(create_default_model, sagemaker_session):
329334

330-
tf = TensorFlow(entry_point=SCRIPT_PATH, role=ROLE, sagemaker_session=sagemaker_session,
331-
train_instance_count=INSTANCE_COUNT, train_instance_type=INSTANCE_TYPE)
335+
tf = TensorFlow(
336+
entry_point=SCRIPT_PATH,
337+
role=ROLE,
338+
sagemaker_session=sagemaker_session,
339+
train_instance_count=INSTANCE_COUNT,
340+
train_instance_type=INSTANCE_TYPE,
341+
)
332342

333343
tf.latest_training_job = _TrainingJob(sagemaker_session, JOB_NAME)
334344
transformer = tf.transformer(INSTANCE_COUNT, INSTANCE_TYPE)
335345
assert isinstance(transformer, Transformer)
336-
create_default_model.assert_called_once()
337346
assert transformer.sagemaker_session == sagemaker_session
338347
assert transformer.instance_count == INSTANCE_COUNT
339348
assert transformer.instance_type == INSTANCE_TYPE
340-
assert transformer.tags is None
341349
assert tf.script_mode is False
342350
assert tf._script_mode_enabled() is False
351+
create_default_model.assert_called_once()
343352

344353

345354
def test_create_model_with_custom_image(sagemaker_session):

0 commit comments

Comments
 (0)