File tree Expand file tree Collapse file tree 2 files changed +8
-1
lines changed Expand file tree Collapse file tree 2 files changed +8
-1
lines changed Original file line number Diff line number Diff line change @@ -1954,6 +1954,7 @@ class Framework(EstimatorBase):
1954
1954
INSTANCE_TYPE = "sagemaker_instance_type"
1955
1955
MPI_NUM_PROCESSES_PER_HOST = "sagemaker_mpi_num_of_processes_per_host"
1956
1956
MPI_CUSTOM_MPI_OPTIONS = "sagemaker_mpi_custom_mpi_options"
1957
+ SM_DDP_CUSTOM_MPI_OPTIONS = "sagemaker_distributed_dataparallel_custom_mpi_options"
1957
1958
CONTAINER_CODE_CHANNEL_SOURCEDIR_PATH = "/opt/ml/input/data/code/sourcedir.tar.gz"
1958
1959
1959
1960
def __init__ (
@@ -2629,6 +2630,10 @@ def _distribution_configuration(self, distribution):
2629
2630
smdataparallel_enabled = smdistributed .get ("dataparallel" , {}).get ("enabled" , False )
2630
2631
distribution_config [self .LAUNCH_SM_DDP_ENV_NAME ] = smdataparallel_enabled
2631
2632
distribution_config [self .INSTANCE_TYPE ] = self .instance_type
2633
+ if smdataparallel_enabled :
2634
+ distribution_config [self .SM_DDP_CUSTOM_MPI_OPTIONS ] = smdistributed ["dataparallel" ].get (
2635
+ "custom_mpi_options" , ""
2636
+ )
2632
2637
2633
2638
return distribution_config
2634
2639
Original file line number Diff line number Diff line change 121
121
DISTRIBUTION_MPI_ENABLED = {
122
122
"mpi" : {"enabled" : True , "custom_mpi_options" : "options" , "processes_per_host" : 2 }
123
123
}
124
- DISTRIBUTION_SM_DDP_ENABLED = {"smdistributed" : {"dataparallel" : {"enabled" : True }}}
124
+ DISTRIBUTION_SM_DDP_ENABLED = {"smdistributed" : {"dataparallel" : {"enabled" : True ,
125
+ "custom_mpi_options" : "options" }}}
125
126
126
127
127
128
class DummyFramework (Framework ):
@@ -3290,6 +3291,7 @@ def test_framework_distribution_configuration(sagemaker_session):
3290
3291
actual_ddp = framework ._distribution_configuration (distribution = DISTRIBUTION_SM_DDP_ENABLED )
3291
3292
expected_ddp = {
3292
3293
"sagemaker_distributed_dataparallel_enabled" : True ,
3294
+ "sagemaker_distributed_dataparallel_custom_mpi_options" : "options" ,
3293
3295
"sagemaker_instance_type" : INSTANCE_TYPE ,
3294
3296
}
3295
3297
assert actual_ddp == expected_ddp
You can’t perform that action at this time.
0 commit comments