Skip to content

Commit 3fcc5d7

Browse files
committed
black check
1 parent 9b316d3 commit 3fcc5d7

File tree

4 files changed

+30
-16
lines changed

4 files changed

+30
-16
lines changed

tests/integ/test_tf_script_mode.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ def test_mnist_async(sagemaker_session):
159159
training_job_name = estimator.latest_training_job.name
160160
time.sleep(20)
161161
endpoint_name = training_job_name
162-
model_name = 'model-name-1'
162+
model_name = "model-name-1"
163163
_assert_training_job_tags_match(
164164
sagemaker_session.sagemaker_client, estimator.latest_training_job.name, TAGS
165165
)
@@ -168,8 +168,10 @@ def test_mnist_async(sagemaker_session):
168168
training_job_name=training_job_name, sagemaker_session=sagemaker_session
169169
)
170170
predictor = estimator.deploy(
171-
initial_instance_count=1, instance_type="ml.c4.xlarge", endpoint_name=endpoint_name,
172-
model_name=model_name
171+
initial_instance_count=1,
172+
instance_type="ml.c4.xlarge",
173+
endpoint_name=endpoint_name,
174+
model_name=model_name,
173175
)
174176

175177
result = predictor.predict(np.zeros(784))
@@ -178,9 +180,7 @@ def test_mnist_async(sagemaker_session):
178180
_assert_model_tags_match(
179181
sagemaker_session.sagemaker_client, estimator.latest_training_job.name, TAGS
180182
)
181-
_assert_model_name_match(
182-
sagemaker_session.sagemaker_client, endpoint_name, model_name
183-
)
183+
_assert_model_name_match(sagemaker_session.sagemaker_client, endpoint_name, model_name)
184184

185185

186186
def test_deploy_with_input_handlers(sagemaker_session, instance_type):
@@ -252,4 +252,4 @@ def _assert_model_name_match(sagemaker_client, endpoint_config_name, model_name)
252252
endpoint_config_description = sagemaker_client.describe_endpoint_config(
253253
EndpointConfigName=endpoint_config_name
254254
)
255-
assert model_name == endpoint_config_description['ProductionVariants'][0]['ModelName']
255+
assert model_name == endpoint_config_description["ProductionVariants"][0]["ModelName"]

tests/integ/test_tuner.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -844,7 +844,7 @@ def test_attach_tuning_pytorch(sagemaker_session):
844844
tuner.wait()
845845

846846
endpoint_name = tuning_job_name
847-
model_name = 'model-name-1'
847+
model_name = "model-name-1"
848848
attached_tuner = HyperparameterTuner.attach(
849849
tuning_job_name, sagemaker_session=sagemaker_session
850850
)
@@ -948,5 +948,7 @@ def _fm_serializer(data):
948948

949949

950950
def _assert_model_name_match(sagemaker_client, endpoint_config_name, model_name):
951-
endpoint_config_description = sagemaker_client.describe_endpoint_config(EndpointConfigName=endpoint_config_name)
952-
assert model_name == endpoint_config_description['ProductionVariants'][0]['ModelName']
951+
endpoint_config_description = sagemaker_client.describe_endpoint_config(
952+
EndpointConfigName=endpoint_config_name
953+
)
954+
assert model_name == endpoint_config_description["ProductionVariants"][0]["ModelName"]

tests/unit/test_estimator.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1712,8 +1712,14 @@ def test_deploy_with_update_endpoint(sagemaker_session):
17121712

17131713

17141714
def test_deploy_with_model_name(sagemaker_session):
1715-
estimator = Estimator(IMAGE_NAME, ROLE, INSTANCE_COUNT, INSTANCE_TYPE, output_path=OUTPUT_PATH,
1716-
sagemaker_session=sagemaker_session)
1715+
estimator = Estimator(
1716+
IMAGE_NAME,
1717+
ROLE,
1718+
INSTANCE_COUNT,
1719+
INSTANCE_TYPE,
1720+
output_path=OUTPUT_PATH,
1721+
sagemaker_session=sagemaker_session,
1722+
)
17171723
estimator.set_hyperparameters(**HYPERPARAMS)
17181724
estimator.fit({"train": "s3://bucket/training-prefix"})
17191725
model_name = "model-name"
@@ -1725,10 +1731,16 @@ def test_deploy_with_model_name(sagemaker_session):
17251731

17261732

17271733
def test_deploy_with_no_model_name(sagemaker_session):
1728-
estimator = Estimator(IMAGE_NAME, ROLE, INSTANCE_COUNT, INSTANCE_TYPE, output_path=OUTPUT_PATH,
1729-
sagemaker_session=sagemaker_session)
1734+
estimator = Estimator(
1735+
IMAGE_NAME,
1736+
ROLE,
1737+
INSTANCE_COUNT,
1738+
INSTANCE_TYPE,
1739+
output_path=OUTPUT_PATH,
1740+
sagemaker_session=sagemaker_session,
1741+
)
17301742
estimator.set_hyperparameters(**HYPERPARAMS)
1731-
estimator.fit({'train': 's3://bucket/training-prefix'})
1743+
estimator.fit({"train": "s3://bucket/training-prefix"})
17321744
estimator.deploy(INSTANCE_COUNT, INSTANCE_TYPE)
17331745

17341746
sagemaker_session.create_model.assert_called_once()

tests/unit/test_tuner.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -646,7 +646,7 @@ def test_deploy_default(tuner):
646646
tuner.estimator.sagemaker_session.create_model.assert_called_once()
647647
args = tuner.estimator.sagemaker_session.create_model.call_args[0]
648648

649-
assert args[0] == 'neo'
649+
assert args[0] == "neo"
650650
assert args[1] == ROLE
651651
assert args[2]["Image"] == IMAGE_NAME
652652
assert args[2]["ModelDataUrl"] == MODEL_DATA

0 commit comments

Comments
 (0)