@@ -632,10 +632,15 @@ def test_validate_smdataparallel_args_not_raises():
632
632
(None , None , None , None , smdataparallel_disabled ),
633
633
("ml.p3.16xlarge" , "tensorflow" , "2.3.1" , "py3" , smdataparallel_enabled ),
634
634
("ml.p3.16xlarge" , "tensorflow" , "2.3.2" , "py3" , smdataparallel_enabled ),
635
+ ("ml.p3.16xlarge" , "tensorflow" , "2.3" , "py3" , smdataparallel_enabled ),
635
636
("ml.p3.16xlarge" , "tensorflow" , "2.4.1" , "py3" , smdataparallel_enabled ),
637
+ ("ml.p3.16xlarge" , "tensorflow" , "2.4" , "py3" , smdataparallel_enabled ),
636
638
("ml.p3.16xlarge" , "pytorch" , "1.6.0" , "py3" , smdataparallel_enabled ),
639
+ ("ml.p3.16xlarge" , "pytorch" , "1.6" , "py3" , smdataparallel_enabled ),
637
640
("ml.p3.16xlarge" , "pytorch" , "1.7.1" , "py3" , smdataparallel_enabled ),
641
+ ("ml.p3.16xlarge" , "pytorch" , "1.7" , "py3" , smdataparallel_enabled ),
638
642
("ml.p3.16xlarge" , "pytorch" , "1.8.0" , "py3" , smdataparallel_enabled ),
643
+ ("ml.p3.16xlarge" , "pytorch" , "1.8" , "py3" , smdataparallel_enabled ),
639
644
]
640
645
for instance_type , framework_name , framework_version , py_version , distribution in good_args :
641
646
fw_utils ._validate_smdataparallel_args (
0 commit comments