Skip to content

Commit 7332b6f

Browse files
committed
change: allow smdistributed to be enabled with torch_distributed.
1 parent 64b2f47 commit 7332b6f

File tree

2 files changed

+5 -6 lines changed

2 files changed

+5 -6 lines changed

src/sagemaker/estimator.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3842,14 +3842,10 @@ def _distribution_configuration(self, distribution):
3842 3842
"custom_mpi_options", ""
3843 3843
)
3844 3844

3845-
if get_mp_parameters(distribution):
3846-
distribution_config["mp_parameters"] = get_mp_parameters(distribution)
3847-
3848-
elif "modelparallel" in distribution.get("smdistributed", {}):
3849-
raise ValueError("Cannot use Model Parallelism without MPI enabled!")
3850-
3851 3845
if "smdistributed" in distribution:
3852 3846
# smdistributed strategy selected
3847+
if get_mp_parameters(distribution):
3848+
distribution_config["mp_parameters"] = get_mp_parameters(distribution)
3853 3849
smdistributed = distribution["smdistributed"]
3854 3850
smdataparallel_enabled = smdistributed.get("dataparallel", {}).get("enabled", False)
3855 3851
distribution_config[self.LAUNCH_SM_DDP_ENV_NAME] = smdataparallel_enabled

src/sagemaker/pytorch/estimator.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -326,6 +326,9 @@ def _pytorch_distribution_configuration(self, distribution):
326 326
if self.instance_type is not None:
327 327
distribution_config[self.INSTANCE_TYPE_ENV_NAME] = self.instance_type
328 328
elif torch_distributed_enabled:
329+
if "smdistributed" in distribution:
330+
# Enable torch_distributed for smdistributed.
331+
distribution_config = self._distribution_configuration(distribution=distribution)
329 332
distribution_config[self.LAUNCH_TORCH_DISTRIBUTED_ENV_NAME] = torch_distributed_enabled
330 333
if self.instance_type is not None:
331 334
distribution_config[self.INSTANCE_TYPE_ENV_NAME] = self.instance_type

0 commit comments

Comments (0)