@@ -553,11 +553,15 @@ def test_validate_version_or_image_args_raises():
553
553
554
554
def test_validate_smdistributed_not_raises ():
555
555
smdataparallel_enabled = {"smdistributed" : {"dataparallel" : {"enabled" : True }}}
556
+ smdataparallel_enabled_custom_mpi = {
557
+ "smdistributed" : {"dataparallel" : {"enabled" : True , "custom_mpi_options" : "--verbose" }}
558
+ }
556
559
smdataparallel_disabled = {"smdistributed" : {"dataparallel" : {"enabled" : False }}}
557
560
instance_types = list (fw_utils .SM_DATAPARALLEL_SUPPORTED_INSTANCE_TYPES )
558
561
559
562
good_args = [
560
563
(smdataparallel_enabled , "custom-container" ),
564
+ (smdataparallel_enabled_custom_mpi , "custom-container" ),
561
565
(smdataparallel_disabled , "custom-container" ),
562
566
]
563
567
frameworks = ["tensorflow" , "pytorch" ]
@@ -576,17 +580,17 @@ def test_validate_smdistributed_not_raises():
576
580
577
581
def test_validate_smdistributed_raises ():
578
582
bad_args = [
579
- {"smdistributed" : {"dataparallel" : {"enabled" : True }}},
580
583
{"smdistributed" : "dummy" },
581
584
{"smdistributed" : {"dummy" }},
582
585
{"smdistributed" : {"dummy" : "val" }},
583
586
{"smdistributed" : {"dummy" : {"enabled" : True }}},
584
587
]
588
+ instance_types = list (fw_utils .SM_DATAPARALLEL_SUPPORTED_INSTANCE_TYPES )
585
589
frameworks = ["tensorflow" , "pytorch" ]
586
- for framework , distribution in product (frameworks , bad_args ):
590
+ for framework , distribution , instance_type in product (frameworks , bad_args , instance_types ):
587
591
with pytest .raises (ValueError ):
588
592
fw_utils .validate_smdistributed (
589
- instance_type = None ,
593
+ instance_type = instance_type ,
590
594
framework_name = framework ,
591
595
framework_version = None ,
592
596
py_version = None ,
@@ -624,6 +628,9 @@ def test_validate_smdataparallel_args_raises():
624
628
625
629
def test_validate_smdataparallel_args_not_raises ():
626
630
smdataparallel_enabled = {"smdistributed" : {"dataparallel" : {"enabled" : True }}}
631
+ smdataparallel_enabled_custom_mpi = {
632
+ "smdistributed" : {"dataparallel" : {"enabled" : True , "custom_mpi_options" : "--verbose" }}
633
+ }
627
634
smdataparallel_disabled = {"smdistributed" : {"dataparallel" : {"enabled" : False }}}
628
635
629
636
# Cases {PT|TF2}
@@ -643,6 +650,8 @@ def test_validate_smdataparallel_args_not_raises():
643
650
("ml.p3.16xlarge" , "pytorch" , "1.7" , "py3" , smdataparallel_enabled ),
644
651
("ml.p3.16xlarge" , "pytorch" , "1.8.0" , "py3" , smdataparallel_enabled ),
645
652
("ml.p3.16xlarge" , "pytorch" , "1.8" , "py3" , smdataparallel_enabled ),
653
+ ("ml.p3.16xlarge" , "tensorflow" , "2.4.1" , "py3" , smdataparallel_enabled_custom_mpi ),
654
+ ("ml.p3.16xlarge" , "pytorch" , "1.8.0" , "py3" , smdataparallel_enabled_custom_mpi ),
646
655
]
647
656
for instance_type , framework_name , framework_version , py_version , distribution in good_args :
648
657
fw_utils ._validate_smdataparallel_args (
0 commit comments