File tree Expand file tree Collapse file tree 2 files changed +8
-1
lines changed Expand file tree Collapse file tree 2 files changed +8
-1
lines changed Original file line number Diff line number Diff line change @@ -1944,6 +1944,7 @@ class Framework(EstimatorBase):
1944
1944
INSTANCE_TYPE = "sagemaker_instance_type"
1945
1945
MPI_NUM_PROCESSES_PER_HOST = "sagemaker_mpi_num_of_processes_per_host"
1946
1946
MPI_CUSTOM_MPI_OPTIONS = "sagemaker_mpi_custom_mpi_options"
1947
+ SM_DDP_CUSTOM_MPI_OPTIONS = "sagemaker_distributed_dataparallel_custom_mpi_options"
1947
1948
CONTAINER_CODE_CHANNEL_SOURCEDIR_PATH = "/opt/ml/input/data/code/sourcedir.tar.gz"
1948
1949
1949
1950
def __init__ (
@@ -2605,6 +2606,10 @@ def _distribution_configuration(self, distribution):
2605
2606
smdataparallel_enabled = smdistributed .get ("dataparallel" , {}).get ("enabled" , False )
2606
2607
distribution_config [self .LAUNCH_SM_DDP_ENV_NAME ] = smdataparallel_enabled
2607
2608
distribution_config [self .INSTANCE_TYPE ] = self .instance_type
2609
+ if smdataparallel_enabled :
2610
+ distribution_config [self .SM_DDP_CUSTOM_MPI_OPTIONS ] = smdistributed ["dataparallel" ].get (
2611
+ "custom_mpi_options" , ""
2612
+ )
2608
2613
2609
2614
return distribution_config
2610
2615
Original file line number Diff line number Diff line change 120
120
DISTRIBUTION_MPI_ENABLED = {
121
121
"mpi" : {"enabled" : True , "custom_mpi_options" : "options" , "processes_per_host" : 2 }
122
122
}
123
- DISTRIBUTION_SM_DDP_ENABLED = {"smdistributed" : {"dataparallel" : {"enabled" : True }}}
123
+ DISTRIBUTION_SM_DDP_ENABLED = {"smdistributed" : {"dataparallel" : {"enabled" : True ,
124
+ "custom_mpi_options" : "options" }}}
124
125
125
126
126
127
class DummyFramework (Framework ):
@@ -3267,6 +3268,7 @@ def test_framework_distribution_configuration(sagemaker_session):
3267
3268
actual_ddp = framework ._distribution_configuration (distribution = DISTRIBUTION_SM_DDP_ENABLED )
3268
3269
expected_ddp = {
3269
3270
"sagemaker_distributed_dataparallel_enabled" : True ,
3271
+ "sagemaker_distributed_dataparallel_custom_mpi_options" : "options" ,
3270
3272
"sagemaker_instance_type" : INSTANCE_TYPE ,
3271
3273
}
3272
3274
assert actual_ddp == expected_ddp
You can’t perform that action at this time.
0 commit comments