Skip to content

Commit b464716

Browse files
committed
Added unit and integration tests
1 parent 2f5aed4 commit b464716

File tree

3 files changed

+18
-5
lines changed

3 files changed

+18
-5
lines changed

tests/integ/test_smdataparallel_pt.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -47,7 +47,9 @@ def test_smdataparallel_pt_mnist(
4747
sagemaker_session=sagemaker_session,
4848
framework_version=pytorch_training_latest_version,
4949
py_version=pytorch_training_latest_py_version,
50-
distribution={"smdistributed": {"dataparallel": {"enabled": True}}},
50+
distribution={
51+
"smdistributed": {"dataparallel": {"enabled": True, "custom_mpi_options": "--verbose"}}
52+
},
5153
)
5254

5355
with timeout.timeout(minutes=integ.TRAINING_DEFAULT_TIMEOUT_MINUTES):

tests/integ/test_smdataparallel_tf.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -47,7 +47,9 @@ def test_smdataparallel_tf_mnist(
4747
sagemaker_session=sagemaker_session,
4848
framework_version=tensorflow_training_latest_version,
4949
py_version=tensorflow_training_latest_py_version,
50-
distribution={"smdistributed": {"dataparallel": {"enabled": True}}},
50+
distribution={
51+
"smdistributed": {"dataparallel": {"enabled": True, "custom_mpi_options": "--verbose"}}
52+
},
5153
)
5254

5355
with timeout.timeout(minutes=integ.TRAINING_DEFAULT_TIMEOUT_MINUTES):

tests/unit/test_fw_utils.py

Lines changed: 12 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -553,11 +553,15 @@ def test_validate_version_or_image_args_raises():
553553

554554
def test_validate_smdistributed_not_raises():
555555
smdataparallel_enabled = {"smdistributed": {"dataparallel": {"enabled": True}}}
556+
smdataparallel_enabled_custom_mpi = {
557+
"smdistributed": {"dataparallel": {"enabled": True, "custom_mpi_options": "--verbose"}}
558+
}
556559
smdataparallel_disabled = {"smdistributed": {"dataparallel": {"enabled": False}}}
557560
instance_types = list(fw_utils.SM_DATAPARALLEL_SUPPORTED_INSTANCE_TYPES)
558561

559562
good_args = [
560563
(smdataparallel_enabled, "custom-container"),
564+
(smdataparallel_enabled_custom_mpi, "custom-container"),
561565
(smdataparallel_disabled, "custom-container"),
562566
]
563567
frameworks = ["tensorflow", "pytorch"]
@@ -576,17 +580,17 @@ def test_validate_smdistributed_not_raises():
576580

577581
def test_validate_smdistributed_raises():
578582
bad_args = [
579-
{"smdistributed": {"dataparallel": {"enabled": True}}},
580583
{"smdistributed": "dummy"},
581584
{"smdistributed": {"dummy"}},
582585
{"smdistributed": {"dummy": "val"}},
583586
{"smdistributed": {"dummy": {"enabled": True}}},
584587
]
588+
instance_types = list(fw_utils.SM_DATAPARALLEL_SUPPORTED_INSTANCE_TYPES)
585589
frameworks = ["tensorflow", "pytorch"]
586-
for framework, distribution in product(frameworks, bad_args):
590+
for framework, distribution, instance_type in product(frameworks, bad_args, instance_types):
587591
with pytest.raises(ValueError):
588592
fw_utils.validate_smdistributed(
589-
instance_type=None,
593+
instance_type=instance_type,
590594
framework_name=framework,
591595
framework_version=None,
592596
py_version=None,
@@ -624,6 +628,9 @@ def test_validate_smdataparallel_args_raises():
624628

625629
def test_validate_smdataparallel_args_not_raises():
626630
smdataparallel_enabled = {"smdistributed": {"dataparallel": {"enabled": True}}}
631+
smdataparallel_enabled_custom_mpi = {
632+
"smdistributed": {"dataparallel": {"enabled": True, "custom_mpi_options": "--verbose"}}
633+
}
627634
smdataparallel_disabled = {"smdistributed": {"dataparallel": {"enabled": False}}}
628635

629636
# Cases {PT|TF2}
@@ -643,6 +650,8 @@ def test_validate_smdataparallel_args_not_raises():
643650
("ml.p3.16xlarge", "pytorch", "1.7", "py3", smdataparallel_enabled),
644651
("ml.p3.16xlarge", "pytorch", "1.8.0", "py3", smdataparallel_enabled),
645652
("ml.p3.16xlarge", "pytorch", "1.8", "py3", smdataparallel_enabled),
653+
("ml.p3.16xlarge", "tensorflow", "2.4.1", "py3", smdataparallel_enabled_custom_mpi),
654+
("ml.p3.16xlarge", "pytorch", "1.8.0", "py3", smdataparallel_enabled_custom_mpi),
646655
]
647656
for instance_type, framework_name, framework_version, py_version, distribution in good_args:
648657
fw_utils._validate_smdataparallel_args(

0 commit comments

Comments
 (0)