File tree Expand file tree Collapse file tree 2 files changed +5
-6
lines changed Expand file tree Collapse file tree 2 files changed +5
-6
lines changed Original file line number Diff line number Diff line change @@ -3842,14 +3842,10 @@ def _distribution_configuration(self, distribution):
3842
3842
"custom_mpi_options" , ""
3843
3843
)
3844
3844
3845
- if get_mp_parameters (distribution ):
3846
- distribution_config ["mp_parameters" ] = get_mp_parameters (distribution )
3847
-
3848
- elif "modelparallel" in distribution .get ("smdistributed" , {}):
3849
- raise ValueError ("Cannot use Model Parallelism without MPI enabled!" )
3850
-
3851
3845
if "smdistributed" in distribution :
3852
3846
# smdistributed strategy selected
3847
+ if get_mp_parameters (distribution ):
3848
+ distribution_config ["mp_parameters" ] = get_mp_parameters (distribution )
3853
3849
smdistributed = distribution ["smdistributed" ]
3854
3850
smdataparallel_enabled = smdistributed .get ("dataparallel" , {}).get ("enabled" , False )
3855
3851
distribution_config [self .LAUNCH_SM_DDP_ENV_NAME ] = smdataparallel_enabled
Original file line number Diff line number Diff line change @@ -326,6 +326,9 @@ def _pytorch_distribution_configuration(self, distribution):
326
326
if self .instance_type is not None :
327
327
distribution_config [self .INSTANCE_TYPE_ENV_NAME ] = self .instance_type
328
328
elif torch_distributed_enabled :
329
+ if "smdistributed" in distribution :
330
+ # Enable torch_distributed for smdistributed.
331
+ distribution_config = self ._distribution_configuration (distribution = distribution )
329
332
distribution_config [self .LAUNCH_TORCH_DISTRIBUTED_ENV_NAME ] = torch_distributed_enabled
330
333
if self .instance_type is not None :
331
334
distribution_config [self .INSTANCE_TYPE_ENV_NAME ] = self .instance_type
You can’t perform that action at this time.
0 commit comments