Skip to content

Commit 44fbcc9

Browse files
feature: enable smdataparallel custom mpi options support (#2255)
Co-authored-by: Ahsan Khan <[email protected]>
1 parent f7161f0 commit 44fbcc9

File tree

5 files changed

+27
-7
lines changed

5 files changed

+27
-7
lines changed

src/sagemaker/estimator.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1954,6 +1954,7 @@ class Framework(EstimatorBase):
19541954
INSTANCE_TYPE = "sagemaker_instance_type"
19551955
MPI_NUM_PROCESSES_PER_HOST = "sagemaker_mpi_num_of_processes_per_host"
19561956
MPI_CUSTOM_MPI_OPTIONS = "sagemaker_mpi_custom_mpi_options"
1957+
SM_DDP_CUSTOM_MPI_OPTIONS = "sagemaker_distributed_dataparallel_custom_mpi_options"
19571958
CONTAINER_CODE_CHANNEL_SOURCEDIR_PATH = "/opt/ml/input/data/code/sourcedir.tar.gz"
19581959

19591960
def __init__(
@@ -2629,6 +2630,10 @@ def _distribution_configuration(self, distribution):
26292630
smdataparallel_enabled = smdistributed.get("dataparallel", {}).get("enabled", False)
26302631
distribution_config[self.LAUNCH_SM_DDP_ENV_NAME] = smdataparallel_enabled
26312632
distribution_config[self.INSTANCE_TYPE] = self.instance_type
2633+
if smdataparallel_enabled:
2634+
distribution_config[self.SM_DDP_CUSTOM_MPI_OPTIONS] = smdistributed[
2635+
"dataparallel"
2636+
].get("custom_mpi_options", "")
26322637

26332638
return distribution_config
26342639

tests/integ/test_smdataparallel_pt.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,9 @@ def test_smdataparallel_pt_mnist(
4747
sagemaker_session=sagemaker_session,
4848
framework_version=pytorch_training_latest_version,
4949
py_version=pytorch_training_latest_py_version,
50-
distribution={"smdistributed": {"dataparallel": {"enabled": True}}},
50+
distribution={
51+
"smdistributed": {"dataparallel": {"enabled": True, "custom_mpi_options": "--verbose"}}
52+
},
5153
)
5254

5355
with timeout.timeout(minutes=integ.TRAINING_DEFAULT_TIMEOUT_MINUTES):

tests/integ/test_smdataparallel_tf.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@
3131
integ.test_region() not in integ.DATA_PARALLEL_TESTING_REGIONS,
3232
reason="Only allow this test to run in IAD and CMH to limit usage of p3.16xlarge",
3333
)
34-
@pytest.mark.skip("Failing due to bad DLC image release. Disable temporarily.")
3534
def test_smdataparallel_tf_mnist(
3635
sagemaker_session,
3736
tensorflow_training_latest_version,
@@ -47,7 +46,9 @@ def test_smdataparallel_tf_mnist(
4746
sagemaker_session=sagemaker_session,
4847
framework_version=tensorflow_training_latest_version,
4948
py_version=tensorflow_training_latest_py_version,
50-
distribution={"smdistributed": {"dataparallel": {"enabled": True}}},
49+
distribution={
50+
"smdistributed": {"dataparallel": {"enabled": True, "custom_mpi_options": "--verbose"}}
51+
},
5152
)
5253

5354
with timeout.timeout(minutes=integ.TRAINING_DEFAULT_TIMEOUT_MINUTES):

tests/unit/test_estimator.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,9 @@
121121
DISTRIBUTION_MPI_ENABLED = {
122122
"mpi": {"enabled": True, "custom_mpi_options": "options", "processes_per_host": 2}
123123
}
124-
DISTRIBUTION_SM_DDP_ENABLED = {"smdistributed": {"dataparallel": {"enabled": True}}}
124+
DISTRIBUTION_SM_DDP_ENABLED = {
125+
"smdistributed": {"dataparallel": {"enabled": True, "custom_mpi_options": "options"}}
126+
}
125127

126128

127129
class DummyFramework(Framework):
@@ -3290,6 +3292,7 @@ def test_framework_distribution_configuration(sagemaker_session):
32903292
actual_ddp = framework._distribution_configuration(distribution=DISTRIBUTION_SM_DDP_ENABLED)
32913293
expected_ddp = {
32923294
"sagemaker_distributed_dataparallel_enabled": True,
3295+
"sagemaker_distributed_dataparallel_custom_mpi_options": "options",
32933296
"sagemaker_instance_type": INSTANCE_TYPE,
32943297
}
32953298
assert actual_ddp == expected_ddp

tests/unit/test_fw_utils.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -553,11 +553,15 @@ def test_validate_version_or_image_args_raises():
553553

554554
def test_validate_smdistributed_not_raises():
555555
smdataparallel_enabled = {"smdistributed": {"dataparallel": {"enabled": True}}}
556+
smdataparallel_enabled_custom_mpi = {
557+
"smdistributed": {"dataparallel": {"enabled": True, "custom_mpi_options": "--verbose"}}
558+
}
556559
smdataparallel_disabled = {"smdistributed": {"dataparallel": {"enabled": False}}}
557560
instance_types = list(fw_utils.SM_DATAPARALLEL_SUPPORTED_INSTANCE_TYPES)
558561

559562
good_args = [
560563
(smdataparallel_enabled, "custom-container"),
564+
(smdataparallel_enabled_custom_mpi, "custom-container"),
561565
(smdataparallel_disabled, "custom-container"),
562566
]
563567
frameworks = ["tensorflow", "pytorch"]
@@ -576,17 +580,17 @@ def test_validate_smdistributed_not_raises():
576580

577581
def test_validate_smdistributed_raises():
578582
bad_args = [
579-
{"smdistributed": {"dataparallel": {"enabled": True}}},
580583
{"smdistributed": "dummy"},
581584
{"smdistributed": {"dummy"}},
582585
{"smdistributed": {"dummy": "val"}},
583586
{"smdistributed": {"dummy": {"enabled": True}}},
584587
]
588+
instance_types = list(fw_utils.SM_DATAPARALLEL_SUPPORTED_INSTANCE_TYPES)
585589
frameworks = ["tensorflow", "pytorch"]
586-
for framework, distribution in product(frameworks, bad_args):
590+
for framework, distribution, instance_type in product(frameworks, bad_args, instance_types):
587591
with pytest.raises(ValueError):
588592
fw_utils.validate_smdistributed(
589-
instance_type=None,
593+
instance_type=instance_type,
590594
framework_name=framework,
591595
framework_version=None,
592596
py_version=None,
@@ -624,6 +628,9 @@ def test_validate_smdataparallel_args_raises():
624628

625629
def test_validate_smdataparallel_args_not_raises():
626630
smdataparallel_enabled = {"smdistributed": {"dataparallel": {"enabled": True}}}
631+
smdataparallel_enabled_custom_mpi = {
632+
"smdistributed": {"dataparallel": {"enabled": True, "custom_mpi_options": "--verbose"}}
633+
}
627634
smdataparallel_disabled = {"smdistributed": {"dataparallel": {"enabled": False}}}
628635

629636
# Cases {PT|TF2}
@@ -644,6 +651,8 @@ def test_validate_smdataparallel_args_not_raises():
644651
("ml.p3.16xlarge", "pytorch", "1.8.0", "py3", smdataparallel_enabled),
645652
("ml.p3.16xlarge", "pytorch", "1.8.1", "py3", smdataparallel_enabled),
646653
("ml.p3.16xlarge", "pytorch", "1.8", "py3", smdataparallel_enabled),
654+
("ml.p3.16xlarge", "tensorflow", "2.4.1", "py3", smdataparallel_enabled_custom_mpi),
655+
("ml.p3.16xlarge", "pytorch", "1.8.0", "py3", smdataparallel_enabled_custom_mpi),
647656
]
648657
for instance_type, framework_name, framework_version, py_version, distribution in good_args:
649658
fw_utils._validate_smdataparallel_args(

0 commit comments

Comments
 (0)