Skip to content

Commit 4e6f2e0

Browse files
viclzhu, rnadimp, rohithn1
authored
change: allow smdistributed to be enabled with torch_distributed. (aws#4129)
Co-authored-by: Rohith Nadimpally <[email protected]> Co-authored-by: rohithn1 <[email protected]>
1 parent 6427fda commit 4e6f2e0

File tree

3 files changed

+131
-7
lines changed

3 files changed

+131
-7
lines changed

src/sagemaker/estimator.py

Lines changed: 35 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,7 @@
101101
)
102102
from sagemaker.workflow import is_pipeline_variable
103103
from sagemaker.workflow.entities import PipelineVariable
104+
from sagemaker.workflow.parameters import ParameterString
104105
from sagemaker.workflow.pipeline_context import PipelineSession, runnable_by_pipeline
105106

106107
logger = logging.getLogger(__name__)
@@ -3198,6 +3199,7 @@ class Framework(EstimatorBase):
31983199
"""
31993200

32003201
_framework_name = None
3202+
UNSUPPORTED_DLC_IMAGE_FOR_SM_PARALLELISM = ("2.0.1-gpu-py310-cu121", "2.0-gpu-py310-cu121")
32013203

32023204
def __init__(
32033205
self,
@@ -3843,16 +3845,43 @@ def _distribution_configuration(self, distribution):
38433845
"custom_mpi_options", ""
38443846
)
38453847

3846-
if get_mp_parameters(distribution):
3847-
distribution_config["mp_parameters"] = get_mp_parameters(distribution)
3848-
3849-
elif "modelparallel" in distribution.get("smdistributed", {}):
3850-
raise ValueError("Cannot use Model Parallelism without MPI enabled!")
3851-
38523848
if "smdistributed" in distribution:
38533849
# smdistributed strategy selected
3850+
if get_mp_parameters(distribution):
3851+
distribution_config["mp_parameters"] = get_mp_parameters(distribution)
3852+
# first make sure torch_distributed is enabled if instance type is p5
3853+
torch_distributed_enabled = False
3854+
if "torch_distributed" in distribution:
3855+
torch_distributed_enabled = distribution.get("torch_distributed").get(
3856+
"enabled", False
3857+
)
38543858
smdistributed = distribution["smdistributed"]
38553859
smdataparallel_enabled = smdistributed.get("dataparallel", {}).get("enabled", False)
3860+
if isinstance(self.instance_type, ParameterString):
3861+
p5_enabled = "p5.48xlarge" in self.instance_type.default_value
3862+
elif isinstance(self.instance_type, str):
3863+
p5_enabled = "p5.48xlarge" in self.instance_type
3864+
else:
3865+
raise ValueError(
3866+
"Invalid object type for instance_type argument. Expected "
3867+
f"{type(str)} or {type(ParameterString)} but got {type(self.instance_type)}."
3868+
)
3869+
img_uri = "" if self.image_uri is None else self.image_uri
3870+
for unsupported_image in Framework.UNSUPPORTED_DLC_IMAGE_FOR_SM_PARALLELISM:
3871+
if (
3872+
unsupported_image in img_uri and not torch_distributed_enabled
3873+
): # disabling DLC images with CUDA12
3874+
raise ValueError(
3875+
f"SMDistributed is currently incompatible with DLC image: {img_uri}. "
3876+
"(Could be due to CUDA version being greater than 11.)"
3877+
)
3878+
if (
3879+
not torch_distributed_enabled and p5_enabled
3880+
): # disabling p5 when torch distributed is disabled
3881+
raise ValueError(
3882+
"SMModelParallel and SMDataParallel currently do not support p5 instances."
3883+
)
3884+
# smdistributed strategy selected with supported instance type
38563885
distribution_config[self.LAUNCH_SM_DDP_ENV_NAME] = smdataparallel_enabled
38573886
distribution_config[self.INSTANCE_TYPE] = self.instance_type
38583887
if smdataparallel_enabled:

src/sagemaker/pytorch/estimator.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -326,6 +326,9 @@ def _pytorch_distribution_configuration(self, distribution):
326326
if self.instance_type is not None:
327327
distribution_config[self.INSTANCE_TYPE_ENV_NAME] = self.instance_type
328328
elif torch_distributed_enabled:
329+
if "smdistributed" in distribution:
330+
# Enable torch_distributed for smdistributed.
331+
distribution_config = self._distribution_configuration(distribution=distribution)
329332
distribution_config[self.LAUNCH_TORCH_DISTRIBUTED_ENV_NAME] = torch_distributed_enabled
330333
if self.instance_type is not None:
331334
distribution_config[self.INSTANCE_TYPE_ENV_NAME] = self.instance_type

tests/unit/test_estimator.py

Lines changed: 93 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,20 @@
169169
"mpi": {"enabled": True, "custom_mpi_options": "options", "processes_per_host": 2}
170170
}
171171
# Fixture configs for Framework._distribution_configuration tests.
# Each dict mirrors the `distribution` argument accepted by Framework estimators.

# smdistributed dataparallel on, torch_distributed explicitly off.
DISTRIBUTION_SM_DDP_ENABLED = {
    "smdistributed": {"dataparallel": {"enabled": True, "custom_mpi_options": "options"}},
    "torch_distributed": {"enabled": False},
}
# smdistributed selected without dataparallel, torch_distributed off.
DISTRIBUTION_SM_DDP_DISABLED = {
    "smdistributed": {"enabled": True},
    "torch_distributed": {"enabled": False},
}
# smdistributed dataparallel on together with torch_distributed enabled
# (the combination this commit newly allows, e.g. for p5 instances).
DISTRIBUTION_SM_TORCH_DIST_AND_DDP_ENABLED = {
    "smdistributed": {"dataparallel": {"enabled": True, "custom_mpi_options": "options"}},
    "torch_distributed": {"enabled": True},
}
# smdistributed without dataparallel, torch_distributed enabled.
DISTRIBUTION_SM_TORCH_DIST_AND_DDP_DISABLED = {
    "smdistributed": {"enabled": True},
    "torch_distributed": {"enabled": True},
}
174187
MOCKED_S3_URI = "s3://mocked_s3_uri_from_source_dir"
175188
_DEFINITION_CONFIG = PipelineDefinitionConfig(use_custom_job_prefix=False)
@@ -310,6 +323,85 @@ def training_job_description(sagemaker_session):
310323
return returned_job_description
311324

312325

326+
def test_validate_smdistributed_unsupported_image_raises(sagemaker_session):
    """CUDA12 DLC images must be rejected for smdistributed, with or without a tag suffix."""
    # First pass uses the bare image name, second pass appends a realistic tag suffix.
    for suffix in ("", "-ubuntu20.04-sagemaker-pr-3303"):
        for bad_image in DummyFramework.UNSUPPORTED_DLC_IMAGE_FOR_SM_PARALLELISM:
            framework = DummyFramework(
                "some_script.py",
                role="DummyRole",
                instance_type="ml.p4d.24xlarge",
                sagemaker_session=sagemaker_session,
                output_path="outputpath",
                image_uri=bad_image + suffix,
            )
            # Both configs fail because torch_distributed is off on a CUDA12 image.
            for distribution in (DISTRIBUTION_SM_DDP_ENABLED, DISTRIBUTION_SM_DDP_DISABLED):
                with pytest.raises(ValueError):
                    framework._distribution_configuration(distribution)
359+
360+
def test_validate_smdistributed_p5_raises(sagemaker_session):
    """p5 instances without torch_distributed must be rejected even on a supported image."""
    framework = DummyFramework(
        "some_script.py",
        role="DummyRole",
        instance_type="ml.p5.48xlarge",
        sagemaker_session=sagemaker_session,
        output_path="outputpath",
        image_uri="some_acceptable_image",
    )
    # Both configs fail: the instance type is p5 while torch_distributed is off.
    for distribution in (DISTRIBUTION_SM_DDP_ENABLED, DISTRIBUTION_SM_DDP_DISABLED):
        with pytest.raises(ValueError):
            framework._distribution_configuration(distribution)
376+
377+
def test_validate_smdistributed_p5_not_raises(sagemaker_session):
    """Enabling torch_distributed makes p5 + CUDA12 image acceptable for smdistributed."""
    framework = DummyFramework(
        "some_script.py",
        role="DummyRole",
        instance_type="ml.p5.48xlarge",
        sagemaker_session=sagemaker_session,
        output_path="outputpath",
        image_uri="ecr-url/2.0.1-gpu-py310-cu121-ubuntu20.04-sagemaker-pr-3303",
    )
    # Neither call should raise once torch_distributed is enabled.
    framework._distribution_configuration(DISTRIBUTION_SM_TORCH_DIST_AND_DDP_ENABLED)
    framework._distribution_configuration(DISTRIBUTION_SM_TORCH_DIST_AND_DDP_DISABLED)
390+
391+
def test_validate_smdistributed_backward_compat_p4_not_raises(sagemaker_session):
    """Existing p4d workflows keep working when torch_distributed is enabled."""
    framework = DummyFramework(
        "some_script.py",
        role="DummyRole",
        instance_type="ml.p4d.24xlarge",
        sagemaker_session=sagemaker_session,
        output_path="outputpath",
        image_uri="some_acceptable_image",
    )
    # Backward compatibility: neither configuration raises on p4d.
    framework._distribution_configuration(DISTRIBUTION_SM_TORCH_DIST_AND_DDP_ENABLED)
    framework._distribution_configuration(DISTRIBUTION_SM_TORCH_DIST_AND_DDP_DISABLED)
404+
313405
def test_framework_all_init_args(sagemaker_session):
314406
f = DummyFramework(
315407
"my_script.py",

0 commit comments

Comments
 (0)