File tree Expand file tree Collapse file tree 2 files changed +5
-6
lines changed Expand file tree Collapse file tree 2 files changed +5
-6
lines changed Original file line number Diff line number Diff line change @@ -3843,14 +3843,10 @@ def _distribution_configuration(self, distribution):
3843
3843
"custom_mpi_options" , ""
3844
3844
)
3845
3845
3846
- if get_mp_parameters (distribution ):
3847
- distribution_config ["mp_parameters" ] = get_mp_parameters (distribution )
3848
-
3849
- elif "modelparallel" in distribution .get ("smdistributed" , {}):
3850
- raise ValueError ("Cannot use Model Parallelism without MPI enabled!" )
3851
-
3852
3846
if "smdistributed" in distribution :
3853
3847
# smdistributed strategy selected
3848
+ if get_mp_parameters (distribution ):
3849
+ distribution_config ["mp_parameters" ] = get_mp_parameters (distribution )
3854
3850
smdistributed = distribution ["smdistributed" ]
3855
3851
smdataparallel_enabled = smdistributed .get ("dataparallel" , {}).get ("enabled" , False )
3856
3852
distribution_config [self .LAUNCH_SM_DDP_ENV_NAME ] = smdataparallel_enabled
Original file line number Diff line number Diff line change @@ -326,6 +326,9 @@ def _pytorch_distribution_configuration(self, distribution):
326
326
if self .instance_type is not None :
327
327
distribution_config [self .INSTANCE_TYPE_ENV_NAME ] = self .instance_type
328
328
elif torch_distributed_enabled :
329
+ if "smdistributed" in distribution :
330
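# smdistributed was requested together with torch_distributed: build the smdistributed configuration first, then set the torch_distributed launch flag on it.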
331
+ distribution_config = self ._distribution_configuration (distribution = distribution )
329
332
distribution_config [self .LAUNCH_TORCH_DISTRIBUTED_ENV_NAME ] = torch_distributed_enabled
330
333
if self .instance_type is not None :
331
334
distribution_config [self .INSTANCE_TYPE_ENV_NAME ] = self .instance_type
You can’t perform that action at this time.
0 commit comments