Skip to content

Commit d8d507a

Browse files
committed
Black-formatting
1 parent 6462e11 commit d8d507a

File tree

2 files changed

+24
-72
lines changed

2 files changed

+24
-72
lines changed

src/sagemaker/pytorch/estimator.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ def __init__(
5353
image_uri: Optional[Union[str, PipelineVariable]] = None,
5454
distribution: Optional[Dict] = None,
5555
compiler_config: Optional[TrainingCompilerConfig] = None,
56-
**kwargs
56+
**kwargs,
5757
):
5858
"""This ``Estimator`` executes a PyTorch script in a managed PyTorch execution environment.
5959
@@ -351,7 +351,7 @@ def create_model(
351351
entry_point=None,
352352
source_dir=None,
353353
dependencies=None,
354-
**kwargs
354+
**kwargs,
355355
):
356356
"""Create a SageMaker ``PyTorchModel`` object that can be deployed to an ``Endpoint``.
357357
@@ -402,7 +402,7 @@ def create_model(
402402
sagemaker_session=self.sagemaker_session,
403403
vpc_config=self.get_vpc_config(vpc_config_override),
404404
dependencies=(dependencies or self.dependencies),
405-
**kwargs
405+
**kwargs,
406406
)
407407

408408
@classmethod

tests/unit/sagemaker/training_compiler/test_pytorch_compiler.py

Lines changed: 21 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -81,9 +81,7 @@ def fixture_sagemaker_session():
8181
return session
8282

8383

84-
def _get_full_gpu_image_uri(
85-
version, instance_type, training_compiler_config
86-
):
84+
def _get_full_gpu_image_uri(version, instance_type, training_compiler_config):
8785
return image_uris.retrieve(
8886
"pytorch-training-compiler",
8987
REGION,
@@ -96,13 +94,9 @@ def _get_full_gpu_image_uri(
9694
)
9795

9896

99-
def _create_train_job(
100-
version, instance_type, training_compiler_config, instance_count=1
101-
):
97+
def _create_train_job(version, instance_type, training_compiler_config, instance_count=1):
10298
return {
103-
"image_uri": _get_full_gpu_image_uri(
104-
version, instance_type, training_compiler_config
105-
),
99+
"image_uri": _get_full_gpu_image_uri(version, instance_type, training_compiler_config),
106100
"input_mode": "File",
107101
"input_config": [
108102
{
@@ -150,15 +144,11 @@ def _create_train_job(
150144
"RuleParameters": {"rule_to_invoke": "ProfilerReport"},
151145
}
152146
],
153-
"profiler_config": {
154-
"S3OutputPath": "s3://{}/".format(BUCKET_NAME),
155-
},
147+
"profiler_config": {"S3OutputPath": "s3://{}/".format(BUCKET_NAME)},
156148
}
157149

158150

159-
def test_unsupported_BYOC(
160-
pytorch_training_compiler_version,
161-
):
151+
def test_unsupported_BYOC(pytorch_training_compiler_version,):
162152
byoc = (
163153
"1.dkr.ecr.us-east-1.amazonaws.com/pytorch-trcomp-training:"
164154
"1.12.0-"
@@ -179,10 +169,7 @@ def test_unsupported_BYOC(
179169
).fit()
180170

181171

182-
def test_unsupported_cpu_instance(
183-
cpu_instance_type,
184-
pytorch_training_compiler_version,
185-
):
172+
def test_unsupported_cpu_instance(cpu_instance_type, pytorch_training_compiler_version):
186173
with pytest.raises(ValueError):
187174
PyTorch(
188175
py_version="py38",
@@ -198,8 +185,7 @@ def test_unsupported_cpu_instance(
198185

199186
@pytest.mark.parametrize("unsupported_gpu_instance_class", UNSUPPORTED_GPU_INSTANCE_CLASSES)
200187
def test_unsupported_gpu_instance(
201-
unsupported_gpu_instance_class,
202-
pytorch_training_compiler_version,
188+
unsupported_gpu_instance_class, pytorch_training_compiler_version
203189
):
204190
with pytest.raises(ValueError):
205191
PyTorch(
@@ -228,9 +214,7 @@ def test_unsupported_framework_version():
228214
).fit()
229215

230216

231-
def test_unsupported_python_2(
232-
pytorch_training_compiler_version,
233-
):
217+
def test_unsupported_python_2(pytorch_training_compiler_version,):
234218
with pytest.raises(ValueError):
235219
PyTorch(
236220
py_version="py27",
@@ -244,9 +228,7 @@ def test_unsupported_python_2(
244228
).fit()
245229

246230

247-
def test_unsupported_instance_group(
248-
pytorch_training_compiler_version,
249-
):
231+
def test_unsupported_instance_group(pytorch_training_compiler_version,):
250232
if Version(pytorch_training_compiler_version) < Version("1.12"):
251233
pytest.skip("This test is intended for PyTorch 1.12 and above")
252234
with pytest.raises(ValueError):
@@ -264,9 +246,7 @@ def test_unsupported_instance_group(
264246
).fit()
265247

266248

267-
def test_unsupported_distribution(
268-
pytorch_training_compiler_version,
269-
):
249+
def test_unsupported_distribution(pytorch_training_compiler_version,):
270250
if Version(pytorch_training_compiler_version) < Version("1.12"):
271251
pytest.skip("This test is intended for PyTorch 1.12 and above")
272252
with pytest.raises(ValueError):
@@ -316,11 +296,7 @@ def test_unsupported_distribution(
316296
@patch("time.time", return_value=TIME)
317297
@pytest.mark.parametrize("instance_class", SUPPORTED_GPU_INSTANCE_CLASSES)
318298
def test_pytorchxla_distribution(
319-
time,
320-
name_from_base,
321-
sagemaker_session,
322-
pytorch_training_compiler_version,
323-
instance_class,
299+
time, name_from_base, sagemaker_session, pytorch_training_compiler_version, instance_class
324300
):
325301
if Version(pytorch_training_compiler_version) < Version("1.12"):
326302
pytest.skip("This test is intended for PyTorch 1.12 and above")
@@ -350,10 +326,7 @@ def test_pytorchxla_distribution(
350326
assert boto_call_names == ["resource"]
351327

352328
expected_train_args = _create_train_job(
353-
pytorch_training_compiler_version,
354-
instance_type,
355-
compiler_config,
356-
instance_count=2,
329+
pytorch_training_compiler_version, instance_type, compiler_config, instance_count=2
357330
)
358331
expected_train_args["input_config"][0]["DataSource"]["S3DataSource"]["S3Uri"] = inputs
359332
expected_train_args["enable_sagemaker_metrics"] = False
@@ -377,11 +350,7 @@ def test_pytorchxla_distribution(
377350
@patch("time.time", return_value=TIME)
378351
@pytest.mark.parametrize("instance_class", SUPPORTED_GPU_INSTANCE_CLASSES)
379352
def test_default_compiler_config(
380-
time,
381-
name_from_base,
382-
sagemaker_session,
383-
pytorch_training_compiler_version,
384-
instance_class,
353+
time, name_from_base, sagemaker_session, pytorch_training_compiler_version, instance_class
385354
):
386355
compiler_config = TrainingCompilerConfig()
387356
instance_type = f"ml.{instance_class}.xlarge"
@@ -408,9 +377,7 @@ def test_default_compiler_config(
408377
assert boto_call_names == ["resource"]
409378

410379
expected_train_args = _create_train_job(
411-
pytorch_training_compiler_version,
412-
instance_type,
413-
compiler_config,
380+
pytorch_training_compiler_version, instance_type, compiler_config
414381
)
415382
expected_train_args["input_config"][0]["DataSource"]["S3DataSource"]["S3Uri"] = inputs
416383
expected_train_args["enable_sagemaker_metrics"] = False
@@ -432,10 +399,7 @@ def test_default_compiler_config(
432399
@patch("sagemaker.estimator.name_from_base", return_value=JOB_NAME)
433400
@patch("time.time", return_value=TIME)
434401
def test_debug_compiler_config(
435-
time,
436-
name_from_base,
437-
sagemaker_session,
438-
pytorch_training_compiler_version,
402+
time, name_from_base, sagemaker_session, pytorch_training_compiler_version
439403
):
440404
compiler_config = TrainingCompilerConfig(debug=True)
441405

@@ -461,9 +425,7 @@ def test_debug_compiler_config(
461425
assert boto_call_names == ["resource"]
462426

463427
expected_train_args = _create_train_job(
464-
pytorch_training_compiler_version,
465-
INSTANCE_TYPE,
466-
compiler_config,
428+
pytorch_training_compiler_version, INSTANCE_TYPE, compiler_config
467429
)
468430
expected_train_args["input_config"][0]["DataSource"]["S3DataSource"]["S3Uri"] = inputs
469431
expected_train_args["enable_sagemaker_metrics"] = False
@@ -485,10 +447,7 @@ def test_debug_compiler_config(
485447
@patch("sagemaker.estimator.name_from_base", return_value=JOB_NAME)
486448
@patch("time.time", return_value=TIME)
487449
def test_disable_compiler_config(
488-
time,
489-
name_from_base,
490-
sagemaker_session,
491-
pytorch_training_compiler_version,
450+
time, name_from_base, sagemaker_session, pytorch_training_compiler_version
492451
):
493452
compiler_config = TrainingCompilerConfig(enabled=False)
494453

@@ -514,9 +473,7 @@ def test_disable_compiler_config(
514473
assert boto_call_names == ["resource"]
515474

516475
expected_train_args = _create_train_job(
517-
pytorch_training_compiler_version,
518-
INSTANCE_TYPE,
519-
compiler_config,
476+
pytorch_training_compiler_version, INSTANCE_TYPE, compiler_config
520477
)
521478
expected_train_args["input_config"][0]["DataSource"]["S3DataSource"]["S3Uri"] = inputs
522479
expected_train_args["enable_sagemaker_metrics"] = False
@@ -536,11 +493,7 @@ def test_disable_compiler_config(
536493
@pytest.mark.parametrize(
537494
["compiler_enabled", "debug_enabled"], [(True, False), (True, True), (False, False)]
538495
)
539-
def test_attach(
540-
sagemaker_session,
541-
compiler_enabled,
542-
debug_enabled,
543-
):
496+
def test_attach(sagemaker_session, compiler_enabled, debug_enabled):
544497
training_image = (
545498
"1.dkr.ecr.us-east-1.amazonaws.com/pytorch-trcomp-training:"
546499
"1.12.0-"
@@ -600,8 +553,7 @@ def test_attach(
600553

601554

602555
def test_register_hf_pytorch_model_auto_infer_framework(
603-
sagemaker_session,
604-
pytorch_training_compiler_version,
556+
sagemaker_session, pytorch_training_compiler_version
605557
):
606558

607559
model_package_group_name = "test-pt-register-model"
@@ -637,7 +589,7 @@ def test_register_hf_pytorch_model_auto_infer_framework(
637589
"ModelDataUrl": ANY,
638590
"Framework": "PYTORCH",
639591
"FrameworkVersion": pytorch_training_compiler_version,
640-
},
592+
}
641593
],
642594
"content_types": content_types,
643595
"response_types": response_types,

0 commit comments

Comments
 (0)