File tree Expand file tree Collapse file tree 2 files changed +5
-6
lines changed Expand file tree Collapse file tree 2 files changed +5
-6
lines changed Original file line number Diff line number Diff line change @@ -3822,14 +3822,10 @@ def _distribution_configuration(self, distribution):
3822
3822
"custom_mpi_options" , ""
3823
3823
)
3824
3824
3825
- if get_mp_parameters (distribution ):
3826
- distribution_config ["mp_parameters" ] = get_mp_parameters (distribution )
3827
-
3828
- elif "modelparallel" in distribution .get ("smdistributed" , {}):
3829
- raise ValueError ("Cannot use Model Parallelism without MPI enabled!" )
3830
-
3831
3825
if "smdistributed" in distribution :
3832
3826
# smdistributed strategy selected
3827
+ if get_mp_parameters (distribution ):
3828
+ distribution_config ["mp_parameters" ] = get_mp_parameters (distribution )
3833
3829
smdistributed = distribution ["smdistributed" ]
3834
3830
smdataparallel_enabled = smdistributed .get ("dataparallel" , {}).get ("enabled" , False )
3835
3831
distribution_config [self .LAUNCH_SM_DDP_ENV_NAME ] = smdataparallel_enabled
Original file line number Diff line number Diff line change @@ -326,6 +326,9 @@ def _pytorch_distribution_configuration(self, distribution):
326
326
if self .instance_type is not None :
327
327
distribution_config [self .INSTANCE_TYPE_ENV_NAME ] = self .instance_type
328
328
elif torch_distributed_enabled :
329
+ if "smdistributed" in distribution :
330
+ # Enable torch_distributed for smdistributed.
331
+ distribution_config = self ._distribution_configuration (distribution = distribution )
329
332
distribution_config [self .LAUNCH_TORCH_DISTRIBUTED_ENV_NAME ] = torch_distributed_enabled
330
333
if self .instance_type is not None :
331
334
distribution_config [self .INSTANCE_TYPE_ENV_NAME ] = self .instance_type
You can’t perform that action at this time.
0 commit comments