Skip to content

Commit 54b7200

Browse files
committed
feature: smdataparallel custom mpi options support
1 parent b7b4549 commit 54b7200

File tree

2 files changed

+8
-1
lines changed

2 files changed

+8
-1
lines changed

src/sagemaker/estimator.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1944,6 +1944,7 @@ class Framework(EstimatorBase):
19441944
INSTANCE_TYPE = "sagemaker_instance_type"
19451945
MPI_NUM_PROCESSES_PER_HOST = "sagemaker_mpi_num_of_processes_per_host"
19461946
MPI_CUSTOM_MPI_OPTIONS = "sagemaker_mpi_custom_mpi_options"
1947+
SM_DDP_CUSTOM_MPI_OPTIONS = "sagemaker_distributed_dataparallel_custom_mpi_options"
19471948
CONTAINER_CODE_CHANNEL_SOURCEDIR_PATH = "/opt/ml/input/data/code/sourcedir.tar.gz"
19481949

19491950
def __init__(
@@ -2605,6 +2606,10 @@ def _distribution_configuration(self, distribution):
26052606
smdataparallel_enabled = smdistributed.get("dataparallel", {}).get("enabled", False)
26062607
distribution_config[self.LAUNCH_SM_DDP_ENV_NAME] = smdataparallel_enabled
26072608
distribution_config[self.INSTANCE_TYPE] = self.instance_type
2609+
if smdataparallel_enabled:
2610+
distribution_config[self.SM_DDP_CUSTOM_MPI_OPTIONS] = smdistributed["dataparallel"].get(
2611+
"custom_mpi_options", ""
2612+
)
26082613

26092614
return distribution_config
26102615

tests/unit/test_estimator.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,8 @@
120120
DISTRIBUTION_MPI_ENABLED = {
121121
"mpi": {"enabled": True, "custom_mpi_options": "options", "processes_per_host": 2}
122122
}
123-
DISTRIBUTION_SM_DDP_ENABLED = {"smdistributed": {"dataparallel": {"enabled": True}}}
123+
DISTRIBUTION_SM_DDP_ENABLED = {"smdistributed": {"dataparallel": {"enabled": True,
124+
"custom_mpi_options": "options"}}}
124125

125126

126127
class DummyFramework(Framework):
@@ -3267,6 +3268,7 @@ def test_framework_distribution_configuration(sagemaker_session):
32673268
actual_ddp = framework._distribution_configuration(distribution=DISTRIBUTION_SM_DDP_ENABLED)
32683269
expected_ddp = {
32693270
"sagemaker_distributed_dataparallel_enabled": True,
3271+
"sagemaker_distributed_dataparallel_custom_mpi_options": "options",
32703272
"sagemaker_instance_type": INSTANCE_TYPE,
32713273
}
32723274
assert actual_ddp == expected_ddp

0 commit comments

Comments
 (0)