Skip to content

Commit cf92c0d

Browse files
committed
python3.10 -m black -l 100
1 parent aac0532 commit cf92c0d

File tree

5 files changed

+28
-20
lines changed

5 files changed

+28
-20
lines changed

src/sagemaker/estimator.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -741,7 +741,6 @@ def _prepare_for_training(self, job_name=None):
741741
self.dependencies = updated_paths["dependencies"]
742742

743743
if self.source_dir or self.entry_point or self.dependencies:
744-
745744
# validate source dir will raise a ValueError if there is something wrong with
746745
# the source directory. We are intentionally not handling it because this is a
747746
# critical error.

tests/conftest.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -277,7 +277,9 @@ def huggingface_pytorch_training_py_version(huggingface_pytorch_training_version
277277

278278

279279
@pytest.fixture(scope="module")
280-
def huggingface_training_compiler_pytorch_version(huggingface_training_compiler_version,):
280+
def huggingface_training_compiler_pytorch_version(
281+
huggingface_training_compiler_version,
282+
):
281283
versions = _huggingface_base_fm_version(
282284
huggingface_training_compiler_version, "pytorch", "huggingface_training_compiler"
283285
)
@@ -290,7 +292,9 @@ def huggingface_training_compiler_pytorch_version(huggingface_training_compiler_
290292

291293

292294
@pytest.fixture(scope="module")
293-
def huggingface_training_compiler_tensorflow_version(huggingface_training_compiler_version,):
295+
def huggingface_training_compiler_tensorflow_version(
296+
huggingface_training_compiler_version,
297+
):
294298
versions = _huggingface_base_fm_version(
295299
huggingface_training_compiler_version, "tensorflow", "huggingface_training_compiler"
296300
)
@@ -321,14 +325,18 @@ def huggingface_training_compiler_pytorch_py_version(
321325

322326

323327
@pytest.fixture(scope="module")
324-
def huggingface_pytorch_latest_training_py_version(huggingface_training_pytorch_latest_version,):
328+
def huggingface_pytorch_latest_training_py_version(
329+
huggingface_training_pytorch_latest_version,
330+
):
325331
return (
326332
"py38" if Version(huggingface_training_pytorch_latest_version) >= Version("1.9") else "py36"
327333
)
328334

329335

330336
@pytest.fixture(scope="module")
331-
def huggingface_pytorch_latest_inference_py_version(huggingface_inference_pytorch_latest_version,):
337+
def huggingface_pytorch_latest_inference_py_version(
338+
huggingface_inference_pytorch_latest_version,
339+
):
332340
return (
333341
"py38"
334342
if Version(huggingface_inference_pytorch_latest_version) >= Version("1.9")

tests/integ/test_tf.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -118,11 +118,11 @@ def test_mnist_with_checkpoint_config(
118118
actual_training_checkpoint_config = sagemaker_session.sagemaker_client.describe_training_job(
119119
TrainingJobName=training_job_name
120120
)["CheckpointConfig"]
121-
actual_training_environment_variable_config = sagemaker_session.sagemaker_client.describe_training_job(
122-
TrainingJobName=training_job_name
123-
)[
124-
"Environment"
125-
]
121+
actual_training_environment_variable_config = (
122+
sagemaker_session.sagemaker_client.describe_training_job(TrainingJobName=training_job_name)[
123+
"Environment"
124+
]
125+
)
126126

127127
expected_retry_strategy = {"MaximumRetryAttempts": 2}
128128
actual_retry_strategy = sagemaker_session.sagemaker_client.describe_training_job(

tests/integ/test_training_compiler.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,6 @@ def test_pytorch(
149149
Test the PyTorch estimator
150150
"""
151151
with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES):
152-
153152
hf = PyTorch(
154153
py_version="py38",
155154
source_dir=os.path.join(DATA_DIR, "huggingface_byoc"),

tests/unit/test_estimator.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1163,7 +1163,9 @@ def test_framework_with_disable_framework_metrics(sagemaker_session):
11631163
assert "profiler_rule_configs" not in args
11641164

11651165

1166-
def test_framework_with_disable_framework_metrics_and_update_system_metrics(sagemaker_session,):
1166+
def test_framework_with_disable_framework_metrics_and_update_system_metrics(
1167+
sagemaker_session,
1168+
):
11671169
f = DummyFramework(
11681170
entry_point=SCRIPT_PATH,
11691171
role=ROLE,
@@ -1183,7 +1185,9 @@ def test_framework_with_disable_framework_metrics_and_update_system_metrics(sage
11831185
assert "profiler_rule_configs" not in args
11841186

11851187

1186-
def test_framework_with_disable_framework_metrics_and_update_framework_params(sagemaker_session,):
1188+
def test_framework_with_disable_framework_metrics_and_update_framework_params(
1189+
sagemaker_session,
1190+
):
11871191
with pytest.raises(ValueError) as error:
11881192
f = DummyFramework(
11891193
entry_point=SCRIPT_PATH,
@@ -3753,7 +3757,6 @@ def test_prepare_init_params_from_job_description_with_training_image_config():
37533757

37543758

37553759
def test_prepare_init_params_from_job_description_with_invalid_training_job():
3756-
37573760
invalid_job_description = RETURNED_JOB_DESCRIPTION.copy()
37583761
invalid_job_description["AlgorithmSpecification"] = {"TrainingInputMode": "File"}
37593762

@@ -3805,7 +3808,9 @@ def test_prepare_for_training_with_name_based_on_algorithm(sagemaker_session):
38053808

38063809

38073810
@patch("sagemaker.workflow.utilities._pipeline_config", MOCKED_PIPELINE_CONFIG)
3808-
def test_prepare_for_training_with_pipeline_name_in_s3_path_no_source_dir(pipeline_session,):
3811+
def test_prepare_for_training_with_pipeline_name_in_s3_path_no_source_dir(
3812+
pipeline_session,
3813+
):
38093814
# script_uri is NOT provided -> use new cache key behavior that builds path using pipeline name + code_hash
38103815
image_uri = "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-training:1.9.0-gpu-py38"
38113816
model_uri = "s3://someprefix2/models/model.tar.gz"
@@ -3993,7 +3998,6 @@ def test_script_mode_estimator(patched_stage_user_code, sagemaker_session):
39933998
def test_script_mode_estimator_same_calls_as_framework(
39943999
patched_tar_and_upload_dir, sagemaker_session
39954000
):
3996-
39974001
patched_tar_and_upload_dir.return_value = UploadedCode(
39984002
s3_prefix="s3://%s/%s" % ("bucket", "key"), script_name="script_name"
39994003
)
@@ -4253,7 +4257,6 @@ def test_script_mode_estimator_tags_jumpstart_models_with_no_estimator_js_tags(
42534257
def test_all_framework_estimators_add_jumpstart_tags(
42544258
patched_repack_model, patched_upload_code, patched_tar_and_upload_dir, sagemaker_session
42554259
):
4256-
42574260
sagemaker_session.boto_region_name = REGION
42584261
sagemaker_session.sagemaker_client.describe_training_job.return_value = {
42594262
"ModelArtifacts": {"S3ModelArtifacts": "some-uri"}
@@ -4286,7 +4289,7 @@ def test_all_framework_estimators_add_jumpstart_tags(
42864289
}
42874290
jumpstart_model_uri = f"s3://{list(JUMPSTART_BUCKET_NAME_SET)[0]}/model_dirs/model.tar.gz"
42884291
jumpstart_model_uri_2 = f"s3://{list(JUMPSTART_BUCKET_NAME_SET)[1]}/model_dirs/model.tar.gz"
4289-
for (framework_estimator_class, kwargs) in framework_estimator_classes_to_kwargs.items():
4292+
for framework_estimator_class, kwargs in framework_estimator_classes_to_kwargs.items():
42904293
estimator = framework_estimator_class(
42914294
entry_point=ENTRY_POINT,
42924295
role=ROLE,
@@ -4392,7 +4395,6 @@ def test_script_mode_estimator_uses_jumpstart_base_name_with_js_models(
43924395
def test_all_framework_estimators_add_jumpstart_base_name(
43934396
patched_repack_model, patched_upload_code, patched_tar_and_upload_dir, sagemaker_session
43944397
):
4395-
43964398
sagemaker_session.boto_region_name = REGION
43974399
sagemaker_session.sagemaker_client.describe_training_job.return_value = {
43984400
"ModelArtifacts": {"S3ModelArtifacts": "some-uri"}
@@ -4425,7 +4427,7 @@ def test_all_framework_estimators_add_jumpstart_base_name(
44254427
}
44264428
jumpstart_model_uri = f"s3://{list(JUMPSTART_BUCKET_NAME_SET)[0]}/model_dirs/model.tar.gz"
44274429
jumpstart_model_uri_2 = f"s3://{list(JUMPSTART_BUCKET_NAME_SET)[1]}/model_dirs/model.tar.gz"
4428-
for (framework_estimator_class, kwargs) in framework_estimator_classes_to_kwargs.items():
4430+
for framework_estimator_class, kwargs in framework_estimator_classes_to_kwargs.items():
44294431
estimator = framework_estimator_class(
44304432
entry_point=ENTRY_POINT,
44314433
role=ROLE,

0 commit comments

Comments
 (0)