Skip to content

Commit 9881bb1

Browse files
committed
feature: smdataparallel custom mpi options support
1 parent 85321d3 commit 9881bb1

File tree

2 files changed

+8
-1
lines changed

2 files changed

+8
-1
lines changed

src/sagemaker/estimator.py

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1954,6 +1954,7 @@ class Framework(EstimatorBase):
19541954
INSTANCE_TYPE = "sagemaker_instance_type"
19551955
MPI_NUM_PROCESSES_PER_HOST = "sagemaker_mpi_num_of_processes_per_host"
19561956
MPI_CUSTOM_MPI_OPTIONS = "sagemaker_mpi_custom_mpi_options"
1957+
SM_DDP_CUSTOM_MPI_OPTIONS = "sagemaker_distributed_dataparallel_custom_mpi_options"
19571958
CONTAINER_CODE_CHANNEL_SOURCEDIR_PATH = "/opt/ml/input/data/code/sourcedir.tar.gz"
19581959

19591960
def __init__(
@@ -2629,6 +2630,10 @@ def _distribution_configuration(self, distribution):
26292630
smdataparallel_enabled = smdistributed.get("dataparallel", {}).get("enabled", False)
26302631
distribution_config[self.LAUNCH_SM_DDP_ENV_NAME] = smdataparallel_enabled
26312632
distribution_config[self.INSTANCE_TYPE] = self.instance_type
2633+
if smdataparallel_enabled:
2634+
distribution_config[self.SM_DDP_CUSTOM_MPI_OPTIONS] = smdistributed["dataparallel"].get(
2635+
"custom_mpi_options", ""
2636+
)
26322637

26332638
return distribution_config
26342639

tests/unit/test_estimator.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -121,7 +121,8 @@
121121
DISTRIBUTION_MPI_ENABLED = {
122122
"mpi": {"enabled": True, "custom_mpi_options": "options", "processes_per_host": 2}
123123
}
124-
DISTRIBUTION_SM_DDP_ENABLED = {"smdistributed": {"dataparallel": {"enabled": True}}}
124+
DISTRIBUTION_SM_DDP_ENABLED = {"smdistributed": {"dataparallel": {"enabled": True,
125+
"custom_mpi_options": "options"}}}
125126

126127

127128
class DummyFramework(Framework):
@@ -3290,6 +3291,7 @@ def test_framework_distribution_configuration(sagemaker_session):
32903291
actual_ddp = framework._distribution_configuration(distribution=DISTRIBUTION_SM_DDP_ENABLED)
32913292
expected_ddp = {
32923293
"sagemaker_distributed_dataparallel_enabled": True,
3294+
"sagemaker_distributed_dataparallel_custom_mpi_options": "options",
32933295
"sagemaker_instance_type": INSTANCE_TYPE,
32943296
}
32953297
assert actual_ddp == expected_ddp

0 commit comments

Comments
 (0)