@@ -81,9 +81,7 @@ def fixture_sagemaker_session():
81
81
return session
82
82
83
83
84
- def _get_full_gpu_image_uri (
85
- version , instance_type , training_compiler_config
86
- ):
84
+ def _get_full_gpu_image_uri (version , instance_type , training_compiler_config ):
87
85
return image_uris .retrieve (
88
86
"pytorch-training-compiler" ,
89
87
REGION ,
@@ -96,13 +94,9 @@ def _get_full_gpu_image_uri(
96
94
)
97
95
98
96
99
- def _create_train_job (
100
- version , instance_type , training_compiler_config , instance_count = 1
101
- ):
97
+ def _create_train_job (version , instance_type , training_compiler_config , instance_count = 1 ):
102
98
return {
103
- "image_uri" : _get_full_gpu_image_uri (
104
- version , instance_type , training_compiler_config
105
- ),
99
+ "image_uri" : _get_full_gpu_image_uri (version , instance_type , training_compiler_config ),
106
100
"input_mode" : "File" ,
107
101
"input_config" : [
108
102
{
@@ -150,15 +144,11 @@ def _create_train_job(
150
144
"RuleParameters" : {"rule_to_invoke" : "ProfilerReport" },
151
145
}
152
146
],
153
- "profiler_config" : {
154
- "S3OutputPath" : "s3://{}/" .format (BUCKET_NAME ),
155
- },
147
+ "profiler_config" : {"S3OutputPath" : "s3://{}/" .format (BUCKET_NAME )},
156
148
}
157
149
158
150
159
- def test_unsupported_BYOC (
160
- pytorch_training_compiler_version ,
161
- ):
151
+ def test_unsupported_BYOC (pytorch_training_compiler_version ,):
162
152
byoc = (
163
153
"1.dkr.ecr.us-east-1.amazonaws.com/pytorch-trcomp-training:"
164
154
"1.12.0-"
@@ -179,10 +169,7 @@ def test_unsupported_BYOC(
179
169
).fit ()
180
170
181
171
182
- def test_unsupported_cpu_instance (
183
- cpu_instance_type ,
184
- pytorch_training_compiler_version ,
185
- ):
172
+ def test_unsupported_cpu_instance (cpu_instance_type , pytorch_training_compiler_version ):
186
173
with pytest .raises (ValueError ):
187
174
PyTorch (
188
175
py_version = "py38" ,
@@ -198,8 +185,7 @@ def test_unsupported_cpu_instance(
198
185
199
186
@pytest .mark .parametrize ("unsupported_gpu_instance_class" , UNSUPPORTED_GPU_INSTANCE_CLASSES )
200
187
def test_unsupported_gpu_instance (
201
- unsupported_gpu_instance_class ,
202
- pytorch_training_compiler_version ,
188
+ unsupported_gpu_instance_class , pytorch_training_compiler_version
203
189
):
204
190
with pytest .raises (ValueError ):
205
191
PyTorch (
@@ -228,9 +214,7 @@ def test_unsupported_framework_version():
228
214
).fit ()
229
215
230
216
231
- def test_unsupported_python_2 (
232
- pytorch_training_compiler_version ,
233
- ):
217
+ def test_unsupported_python_2 (pytorch_training_compiler_version ,):
234
218
with pytest .raises (ValueError ):
235
219
PyTorch (
236
220
py_version = "py27" ,
@@ -244,9 +228,7 @@ def test_unsupported_python_2(
244
228
).fit ()
245
229
246
230
247
- def test_unsupported_instance_group (
248
- pytorch_training_compiler_version ,
249
- ):
231
+ def test_unsupported_instance_group (pytorch_training_compiler_version ,):
250
232
if Version (pytorch_training_compiler_version ) < Version ("1.12" ):
251
233
pytest .skip ("This test is intended for PyTorch 1.12 and above" )
252
234
with pytest .raises (ValueError ):
@@ -264,9 +246,7 @@ def test_unsupported_instance_group(
264
246
).fit ()
265
247
266
248
267
- def test_unsupported_distribution (
268
- pytorch_training_compiler_version ,
269
- ):
249
+ def test_unsupported_distribution (pytorch_training_compiler_version ,):
270
250
if Version (pytorch_training_compiler_version ) < Version ("1.12" ):
271
251
pytest .skip ("This test is intended for PyTorch 1.12 and above" )
272
252
with pytest .raises (ValueError ):
@@ -316,11 +296,7 @@ def test_unsupported_distribution(
316
296
@patch ("time.time" , return_value = TIME )
317
297
@pytest .mark .parametrize ("instance_class" , SUPPORTED_GPU_INSTANCE_CLASSES )
318
298
def test_pytorchxla_distribution (
319
- time ,
320
- name_from_base ,
321
- sagemaker_session ,
322
- pytorch_training_compiler_version ,
323
- instance_class ,
299
+ time , name_from_base , sagemaker_session , pytorch_training_compiler_version , instance_class
324
300
):
325
301
if Version (pytorch_training_compiler_version ) < Version ("1.12" ):
326
302
pytest .skip ("This test is intended for PyTorch 1.12 and above" )
@@ -350,10 +326,7 @@ def test_pytorchxla_distribution(
350
326
assert boto_call_names == ["resource" ]
351
327
352
328
expected_train_args = _create_train_job (
353
- pytorch_training_compiler_version ,
354
- instance_type ,
355
- compiler_config ,
356
- instance_count = 2 ,
329
+ pytorch_training_compiler_version , instance_type , compiler_config , instance_count = 2
357
330
)
358
331
expected_train_args ["input_config" ][0 ]["DataSource" ]["S3DataSource" ]["S3Uri" ] = inputs
359
332
expected_train_args ["enable_sagemaker_metrics" ] = False
@@ -377,11 +350,7 @@ def test_pytorchxla_distribution(
377
350
@patch ("time.time" , return_value = TIME )
378
351
@pytest .mark .parametrize ("instance_class" , SUPPORTED_GPU_INSTANCE_CLASSES )
379
352
def test_default_compiler_config (
380
- time ,
381
- name_from_base ,
382
- sagemaker_session ,
383
- pytorch_training_compiler_version ,
384
- instance_class ,
353
+ time , name_from_base , sagemaker_session , pytorch_training_compiler_version , instance_class
385
354
):
386
355
compiler_config = TrainingCompilerConfig ()
387
356
instance_type = f"ml.{ instance_class } .xlarge"
@@ -408,9 +377,7 @@ def test_default_compiler_config(
408
377
assert boto_call_names == ["resource" ]
409
378
410
379
expected_train_args = _create_train_job (
411
- pytorch_training_compiler_version ,
412
- instance_type ,
413
- compiler_config ,
380
+ pytorch_training_compiler_version , instance_type , compiler_config
414
381
)
415
382
expected_train_args ["input_config" ][0 ]["DataSource" ]["S3DataSource" ]["S3Uri" ] = inputs
416
383
expected_train_args ["enable_sagemaker_metrics" ] = False
@@ -432,10 +399,7 @@ def test_default_compiler_config(
432
399
@patch ("sagemaker.estimator.name_from_base" , return_value = JOB_NAME )
433
400
@patch ("time.time" , return_value = TIME )
434
401
def test_debug_compiler_config (
435
- time ,
436
- name_from_base ,
437
- sagemaker_session ,
438
- pytorch_training_compiler_version ,
402
+ time , name_from_base , sagemaker_session , pytorch_training_compiler_version
439
403
):
440
404
compiler_config = TrainingCompilerConfig (debug = True )
441
405
@@ -461,9 +425,7 @@ def test_debug_compiler_config(
461
425
assert boto_call_names == ["resource" ]
462
426
463
427
expected_train_args = _create_train_job (
464
- pytorch_training_compiler_version ,
465
- INSTANCE_TYPE ,
466
- compiler_config ,
428
+ pytorch_training_compiler_version , INSTANCE_TYPE , compiler_config
467
429
)
468
430
expected_train_args ["input_config" ][0 ]["DataSource" ]["S3DataSource" ]["S3Uri" ] = inputs
469
431
expected_train_args ["enable_sagemaker_metrics" ] = False
@@ -485,10 +447,7 @@ def test_debug_compiler_config(
485
447
@patch ("sagemaker.estimator.name_from_base" , return_value = JOB_NAME )
486
448
@patch ("time.time" , return_value = TIME )
487
449
def test_disable_compiler_config (
488
- time ,
489
- name_from_base ,
490
- sagemaker_session ,
491
- pytorch_training_compiler_version ,
450
+ time , name_from_base , sagemaker_session , pytorch_training_compiler_version
492
451
):
493
452
compiler_config = TrainingCompilerConfig (enabled = False )
494
453
@@ -514,9 +473,7 @@ def test_disable_compiler_config(
514
473
assert boto_call_names == ["resource" ]
515
474
516
475
expected_train_args = _create_train_job (
517
- pytorch_training_compiler_version ,
518
- INSTANCE_TYPE ,
519
- compiler_config ,
476
+ pytorch_training_compiler_version , INSTANCE_TYPE , compiler_config
520
477
)
521
478
expected_train_args ["input_config" ][0 ]["DataSource" ]["S3DataSource" ]["S3Uri" ] = inputs
522
479
expected_train_args ["enable_sagemaker_metrics" ] = False
@@ -536,11 +493,7 @@ def test_disable_compiler_config(
536
493
@pytest .mark .parametrize (
537
494
["compiler_enabled" , "debug_enabled" ], [(True , False ), (True , True ), (False , False )]
538
495
)
539
- def test_attach (
540
- sagemaker_session ,
541
- compiler_enabled ,
542
- debug_enabled ,
543
- ):
496
+ def test_attach (sagemaker_session , compiler_enabled , debug_enabled ):
544
497
training_image = (
545
498
"1.dkr.ecr.us-east-1.amazonaws.com/pytorch-trcomp-training:"
546
499
"1.12.0-"
@@ -600,8 +553,7 @@ def test_attach(
600
553
601
554
602
555
def test_register_hf_pytorch_model_auto_infer_framework (
603
- sagemaker_session ,
604
- pytorch_training_compiler_version ,
556
+ sagemaker_session , pytorch_training_compiler_version
605
557
):
606
558
607
559
model_package_group_name = "test-pt-register-model"
@@ -637,7 +589,7 @@ def test_register_hf_pytorch_model_auto_infer_framework(
637
589
"ModelDataUrl" : ANY ,
638
590
"Framework" : "PYTORCH" ,
639
591
"FrameworkVersion" : pytorch_training_compiler_version ,
640
- },
592
+ }
641
593
],
642
594
"content_types" : content_types ,
643
595
"response_types" : response_types ,
0 commit comments