45
45
REGION = "us-east-1"
46
46
GPU = "ml.p3.2xlarge"
47
47
SUPPORTED_GPU_INSTANCE_CLASSES = {"p3" , "p3dn" , "g4dn" , "p4d" , "g5" }
48
- UNSUPPORTED_GPU_INSTANCE_CLASSES = (
49
- EC2_GPU_INSTANCE_CLASSES - SUPPORTED_GPU_INSTANCE_CLASSES
50
- )
48
+ UNSUPPORTED_GPU_INSTANCE_CLASSES = EC2_GPU_INSTANCE_CLASSES - SUPPORTED_GPU_INSTANCE_CLASSES
51
49
52
50
LIST_TAGS_RESULT = {"Tags" : [{"Key" : "TagtestKey" , "Value" : "TagtestValue" }]}
53
51
@@ -98,13 +96,9 @@ def _get_full_gpu_image_uri(version, instance_type, training_compiler_config):
98
96
)
99
97
100
98
101
- def _create_train_job (
102
- version , instance_type , training_compiler_config , instance_count = 1
103
- ):
99
+ def _create_train_job (version , instance_type , training_compiler_config , instance_count = 1 ):
104
100
return {
105
- "image_uri" : _get_full_gpu_image_uri (
106
- version , instance_type , training_compiler_config
107
- ),
101
+ "image_uri" : _get_full_gpu_image_uri (version , instance_type , training_compiler_config ),
108
102
"input_mode" : "File" ,
109
103
"input_config" : [
110
104
{
@@ -189,9 +183,7 @@ def test_unsupported_cpu_instance(cpu_instance_type, pytorch_training_compiler_v
189
183
).fit ()
190
184
191
185
192
- @pytest .mark .parametrize (
193
- "unsupported_gpu_instance_class" , UNSUPPORTED_GPU_INSTANCE_CLASSES
194
- )
186
+ @pytest .mark .parametrize ("unsupported_gpu_instance_class" , UNSUPPORTED_GPU_INSTANCE_CLASSES )
195
187
def test_unsupported_gpu_instance (
196
188
unsupported_gpu_instance_class , pytorch_training_compiler_version
197
189
):
@@ -351,19 +343,15 @@ def test_pytorchxla_distribution(
351
343
compiler_config ,
352
344
instance_count = 2 ,
353
345
)
354
- expected_train_args ["input_config" ][0 ]["DataSource" ]["S3DataSource" ][
355
- "S3Uri"
356
- ] = inputs
346
+ expected_train_args ["input_config" ][0 ]["DataSource" ]["S3DataSource" ]["S3Uri" ] = inputs
357
347
expected_train_args ["enable_sagemaker_metrics" ] = False
358
- expected_train_args ["hyperparameters" ][
359
- TrainingCompilerConfig .HP_ENABLE_COMPILER
360
- ] = json .dumps (True )
361
- expected_train_args ["hyperparameters" ][PyTorch .LAUNCH_PT_XLA_ENV_NAME ] = json .dumps (
348
+ expected_train_args ["hyperparameters" ][TrainingCompilerConfig .HP_ENABLE_COMPILER ] = json .dumps (
362
349
True
363
350
)
364
- expected_train_args ["hyperparameters" ][
365
- TrainingCompilerConfig .HP_ENABLE_DEBUG
366
- ] = json .dumps (False )
351
+ expected_train_args ["hyperparameters" ][PyTorch .LAUNCH_PT_XLA_ENV_NAME ] = json .dumps (True )
352
+ expected_train_args ["hyperparameters" ][TrainingCompilerConfig .HP_ENABLE_DEBUG ] = json .dumps (
353
+ False
354
+ )
367
355
368
356
actual_train_args = sagemaker_session .method_calls [0 ][2 ]
369
357
assert (
@@ -411,16 +399,14 @@ def test_default_compiler_config(
411
399
expected_train_args = _create_train_job (
412
400
pytorch_training_compiler_version , instance_type , compiler_config
413
401
)
414
- expected_train_args ["input_config" ][0 ]["DataSource" ]["S3DataSource" ][
415
- "S3Uri"
416
- ] = inputs
402
+ expected_train_args ["input_config" ][0 ]["DataSource" ]["S3DataSource" ]["S3Uri" ] = inputs
417
403
expected_train_args ["enable_sagemaker_metrics" ] = False
418
- expected_train_args ["hyperparameters" ][
419
- TrainingCompilerConfig . HP_ENABLE_COMPILER
420
- ] = json . dumps ( True )
421
- expected_train_args ["hyperparameters" ][
422
- TrainingCompilerConfig . HP_ENABLE_DEBUG
423
- ] = json . dumps ( False )
404
+ expected_train_args ["hyperparameters" ][TrainingCompilerConfig . HP_ENABLE_COMPILER ] = json . dumps (
405
+ True
406
+ )
407
+ expected_train_args ["hyperparameters" ][TrainingCompilerConfig . HP_ENABLE_DEBUG ] = json . dumps (
408
+ False
409
+ )
424
410
425
411
actual_train_args = sagemaker_session .method_calls [0 ][2 ]
426
412
assert (
@@ -465,16 +451,14 @@ def test_debug_compiler_config(
465
451
expected_train_args = _create_train_job (
466
452
pytorch_training_compiler_version , INSTANCE_TYPE , compiler_config
467
453
)
468
- expected_train_args ["input_config" ][0 ]["DataSource" ]["S3DataSource" ][
469
- "S3Uri"
470
- ] = inputs
454
+ expected_train_args ["input_config" ][0 ]["DataSource" ]["S3DataSource" ]["S3Uri" ] = inputs
471
455
expected_train_args ["enable_sagemaker_metrics" ] = False
472
- expected_train_args ["hyperparameters" ][
473
- TrainingCompilerConfig . HP_ENABLE_COMPILER
474
- ] = json . dumps ( True )
475
- expected_train_args ["hyperparameters" ][
476
- TrainingCompilerConfig . HP_ENABLE_DEBUG
477
- ] = json . dumps ( True )
456
+ expected_train_args ["hyperparameters" ][TrainingCompilerConfig . HP_ENABLE_COMPILER ] = json . dumps (
457
+ True
458
+ )
459
+ expected_train_args ["hyperparameters" ][TrainingCompilerConfig . HP_ENABLE_DEBUG ] = json . dumps (
460
+ True
461
+ )
478
462
479
463
actual_train_args = sagemaker_session .method_calls [0 ][2 ]
480
464
assert (
@@ -519,16 +503,14 @@ def test_disable_compiler_config(
519
503
expected_train_args = _create_train_job (
520
504
pytorch_training_compiler_version , INSTANCE_TYPE , compiler_config
521
505
)
522
- expected_train_args ["input_config" ][0 ]["DataSource" ]["S3DataSource" ][
523
- "S3Uri"
524
- ] = inputs
506
+ expected_train_args ["input_config" ][0 ]["DataSource" ]["S3DataSource" ]["S3Uri" ] = inputs
525
507
expected_train_args ["enable_sagemaker_metrics" ] = False
526
- expected_train_args ["hyperparameters" ][
527
- TrainingCompilerConfig . HP_ENABLE_COMPILER
528
- ] = json . dumps ( False )
529
- expected_train_args ["hyperparameters" ][
530
- TrainingCompilerConfig . HP_ENABLE_DEBUG
531
- ] = json . dumps ( False )
508
+ expected_train_args ["hyperparameters" ][TrainingCompilerConfig . HP_ENABLE_COMPILER ] = json . dumps (
509
+ False
510
+ )
511
+ expected_train_args ["hyperparameters" ][TrainingCompilerConfig . HP_ENABLE_DEBUG ] = json . dumps (
512
+ False
513
+ )
532
514
533
515
actual_train_args = sagemaker_session .method_calls [0 ][2 ]
534
516
assert (
@@ -582,9 +564,7 @@ def test_attach(sagemaker_session, compiler_enabled, debug_enabled):
582
564
name = "describe_training_job" , return_value = returned_job_description
583
565
)
584
566
585
- estimator = PyTorch .attach (
586
- training_job_name = "trcomp" , sagemaker_session = sagemaker_session
587
- )
567
+ estimator = PyTorch .attach (training_job_name = "trcomp" , sagemaker_session = sagemaker_session )
588
568
assert estimator .latest_training_job .job_name == "trcomp"
589
569
assert estimator .py_version == "py38"
590
570
assert estimator .framework_version == "1.12.0"
@@ -596,12 +576,12 @@ def test_attach(sagemaker_session, compiler_enabled, debug_enabled):
596
576
assert estimator .output_path == "s3://place/output/trcomp"
597
577
assert estimator .output_kms_key == ""
598
578
assert estimator .hyperparameters ()["training_steps" ] == "100"
599
- assert estimator .hyperparameters ()[
600
- TrainingCompilerConfig . HP_ENABLE_COMPILER
601
- ] == json . dumps ( compiler_enabled )
602
- assert estimator .hyperparameters ()[
603
- TrainingCompilerConfig . HP_ENABLE_DEBUG
604
- ] == json . dumps ( debug_enabled )
579
+ assert estimator .hyperparameters ()[TrainingCompilerConfig . HP_ENABLE_COMPILER ] == json . dumps (
580
+ compiler_enabled
581
+ )
582
+ assert estimator .hyperparameters ()[TrainingCompilerConfig . HP_ENABLE_DEBUG ] == json . dumps (
583
+ debug_enabled
584
+ )
605
585
assert estimator .source_dir == "s3://some/sourcedir.tar.gz"
606
586
assert estimator .entry_point == "iris-dnn-classifier.py"
607
587
0 commit comments