Skip to content

Commit 463f112

Browse files
viclzhubhupendrasingh
authored andcommitted
change: allow smdistributed to be enabled with torch_distributed.
1 parent 488adba commit 463f112

File tree

2 files changed

+5
-6
lines changed

2 files changed

+5
-6
lines changed

src/sagemaker/estimator.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3843,14 +3843,10 @@ def _distribution_configuration(self, distribution):
38433843
"custom_mpi_options", ""
38443844
)
38453845

3846-
if get_mp_parameters(distribution):
3847-
distribution_config["mp_parameters"] = get_mp_parameters(distribution)
3848-
3849-
elif "modelparallel" in distribution.get("smdistributed", {}):
3850-
raise ValueError("Cannot use Model Parallelism without MPI enabled!")
3851-
38523846
if "smdistributed" in distribution:
38533847
# smdistributed strategy selected
3848+
if get_mp_parameters(distribution):
3849+
distribution_config["mp_parameters"] = get_mp_parameters(distribution)
38543850
smdistributed = distribution["smdistributed"]
38553851
smdataparallel_enabled = smdistributed.get("dataparallel", {}).get("enabled", False)
38563852
distribution_config[self.LAUNCH_SM_DDP_ENV_NAME] = smdataparallel_enabled

src/sagemaker/pytorch/estimator.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -326,6 +326,9 @@ def _pytorch_distribution_configuration(self, distribution):
326326
if self.instance_type is not None:
327327
distribution_config[self.INSTANCE_TYPE_ENV_NAME] = self.instance_type
328328
elif torch_distributed_enabled:
329+
if "smdistributed" in distribution:
330+
# Enable torch_distributed for smdistributed.
331+
distribution_config = self._distribution_configuration(distribution=distribution)
329332
distribution_config[self.LAUNCH_TORCH_DISTRIBUTED_ENV_NAME] = torch_distributed_enabled
330333
if self.instance_type is not None:
331334
distribution_config[self.INSTANCE_TYPE_ENV_NAME] = self.instance_type

0 commit comments

Comments
 (0)