Commit 7d67d89

rnadimp authored and goelakash committed
deprecation: Blocking submission of distribution configs that try to use CUDA 12.1 enabled DLC containers and/or p5 instance with smdistributed enabled and torch-distributed disabled.
1 parent 65c8c4e commit 7d67d89
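
For context, here is a minimal sketch (not part of the commit) of the distribution shapes the new check rejects and accepts. The dictionary names are illustrative only; the check itself runs inside Estimator._distribution_configuration when a training job is prepared.

# Blocked after this commit when the job targets ml.p5.48xlarge or uses a
# 2.0.1-gpu-py310-cu121 DLC image: smdistributed is enabled but
# torch_distributed is not.
blocked_distribution = {
    "smdistributed": {"dataparallel": {"enabled": True}},
    "torch_distributed": {"enabled": False},
}

# Accepted: torch_distributed explicitly enabled alongside smdistributed.
allowed_distribution = {
    "smdistributed": {"dataparallel": {"enabled": True}},
    "torch_distributed": {"enabled": True},
}

The unsupported-image check applies whenever the image URI contains 2.0.1-gpu-py310-cu121, independently of the instance type.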

File tree: 2 files changed (+79, -1 lines changed)

src/sagemaker/estimator.py

Lines changed: 11 additions & 0 deletions
@@ -3414,6 +3414,7 @@ def __init__(
         self.checkpoint_s3_uri = checkpoint_s3_uri
         self.checkpoint_local_path = checkpoint_local_path
         self.enable_sagemaker_metrics = enable_sagemaker_metrics
+        self.unsupported_dlc_image_for_sm_parallelism = ["2.0.1-gpu-py310-cu121"]
 
     def _prepare_for_training(self, job_name=None):
         """Set hyperparameters needed for training. This method will also validate ``source_dir``.
@@ -3846,8 +3847,18 @@ def _distribution_configuration(self, distribution):
             # smdistributed strategy selected
             if get_mp_parameters(distribution):
                 distribution_config["mp_parameters"] = get_mp_parameters(distribution)
+            # first make sure torch_distributed is enabled if the instance type is p5
+            torch_distributed_enabled = distribution.get("torch_distributed", {}).get("enabled", False)
             smdistributed = distribution["smdistributed"]
             smdataparallel_enabled = smdistributed.get("dataparallel", {}).get("enabled", False)
+            p5_enabled = "p5.48xlarge" in self.instance_type
+            img_uri = "" if self.image_uri is None else self.image_uri
+            for unsupported_image in self.unsupported_dlc_image_for_sm_parallelism:
+                if unsupported_image in img_uri and not torch_distributed_enabled:  # block CUDA 12 DLC images
+                    raise ValueError(f"SMDistributed is currently incompatible with DLC image: {img_uri}. (Could be due to CUDA version being greater than 11.)")
+            if not torch_distributed_enabled and p5_enabled:  # block p5 when torch_distributed is disabled
+                raise ValueError("SMModelParallel and SMDataParallel currently do not support p5 instances.")
+            # smdistributed strategy selected with supported instance type
             distribution_config[self.LAUNCH_SM_DDP_ENV_NAME] = smdataparallel_enabled
             distribution_config[self.INSTANCE_TYPE] = self.instance_type
             if smdataparallel_enabled:
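
To try the rule outside the SDK, here is a minimal standalone sketch mirroring the patched logic above; the helper name validate_p5_distribution and the module-level constant are illustrative and not part of the SageMaker SDK.

# Standalone sketch of the new rule; mirrors the patched _distribution_configuration
# logic but is not part of the SageMaker SDK (names are illustrative only).
UNSUPPORTED_DLC_IMAGES_FOR_SM_PARALLELISM = ["2.0.1-gpu-py310-cu121"]

def validate_p5_distribution(distribution, instance_type, image_uri=None):
    """Raise ValueError for smdistributed configs this commit blocks."""
    if "smdistributed" not in distribution:
        return  # the check only applies when an smdistributed strategy is selected
    torch_distributed_enabled = distribution.get("torch_distributed", {}).get("enabled", False)
    img_uri = image_uri or ""
    if not torch_distributed_enabled:
        # CUDA 12.1 DLC images are blocked unless torch_distributed is enabled
        if any(image in img_uri for image in UNSUPPORTED_DLC_IMAGES_FOR_SM_PARALLELISM):
            raise ValueError(f"SMDistributed is currently incompatible with DLC image: {img_uri}.")
        # p5 instances are blocked unless torch_distributed is enabled
        if "p5.48xlarge" in instance_type:
            raise ValueError("SMModelParallel and SMDataParallel currently do not support p5 instances.")

# Example: this raises ValueError, matching the unit tests below.
# validate_p5_distribution(
#     {"smdistributed": {"dataparallel": {"enabled": True}}}, "ml.p5.48xlarge"
# )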

tests/unit/test_estimator.py

Lines changed: 68 additions & 1 deletion
@@ -169,7 +169,20 @@
     "mpi": {"enabled": True, "custom_mpi_options": "options", "processes_per_host": 2}
 }
 DISTRIBUTION_SM_DDP_ENABLED = {
-    "smdistributed": {"dataparallel": {"enabled": True, "custom_mpi_options": "options"}}
+    "smdistributed": {"dataparallel": {"enabled": True, "custom_mpi_options": "options"}},
+    "torch_distributed": {"enabled": False}
+}
+DISTRIBUTION_SM_DDP_DISABLED = {
+    "smdistributed": {"enabled": True},
+    "torch_distributed": {"enabled": False}
+}
+DISTRIBUTION_SM_TORCH_DIST_AND_DDP_ENABLED = {
+    "smdistributed": {"dataparallel": {"enabled": True, "custom_mpi_options": "options"}},
+    "torch_distributed": {"enabled": True}
+}
+DISTRIBUTION_SM_TORCH_DIST_AND_DDP_DISABLED = {
+    "smdistributed": {"enabled": True},
+    "torch_distributed": {"enabled": True}
 }
 MOCKED_S3_URI = "s3://mocked_s3_uri_from_source_dir"
 _DEFINITION_CONFIG = PipelineDefinitionConfig(use_custom_job_prefix=False)
@@ -309,6 +322,60 @@ def training_job_description(sagemaker_session):
     sagemaker_session.describe_training_job = mock_describe_training_job
     return returned_job_description
 
+def test_validate_smdistributed_p5_raises(sagemaker_session):
+    # supported DLC image
+    f = DummyFramework(
+        "some_script.py",
+        role="DummyRole",
+        instance_type="ml.p5.48xlarge",
+        sagemaker_session=sagemaker_session,
+        output_path="outputpath",
+        image_uri="some_acceptable_image"
+    )
+    # both fail because the instance type is p5 and torch_distributed is off
+    with pytest.raises(ValueError):
+        f._distribution_configuration(DISTRIBUTION_SM_DDP_ENABLED)
+    with pytest.raises(ValueError):
+        f._distribution_configuration(DISTRIBUTION_SM_DDP_DISABLED)
+    # unsupported DLC image
+    f = DummyFramework(
+        "some_script.py",
+        role="DummyRole",
+        instance_type="ml.p5.48xlarge",
+        sagemaker_session=sagemaker_session,
+        output_path="outputpath",
+        image_uri="ecr-url/2.0.1-gpu-py310-cu121-ubuntu20.04-sagemaker-pr-3303"
+    )
+    # both fail due to the unsupported CUDA 12 DLC image
+    with pytest.raises(ValueError):
+        f._distribution_configuration(DISTRIBUTION_SM_DDP_ENABLED)
+    with pytest.raises(ValueError):
+        f._distribution_configuration(DISTRIBUTION_SM_DDP_DISABLED)
+
+def test_validate_smdistributed_p5_not_raises(sagemaker_session):
+    f = DummyFramework(
+        "some_script.py",
+        role="DummyRole",
+        instance_type="ml.p5.48xlarge",
+        sagemaker_session=sagemaker_session,
+        output_path="outputpath",
+        image_uri="ecr-url/2.0.1-gpu-py310-cu121-ubuntu20.04-sagemaker-pr-3303"
+    )
+    # a p5 instance with torch_distributed enabled is accepted
+    f._distribution_configuration(DISTRIBUTION_SM_TORCH_DIST_AND_DDP_ENABLED)
+    f._distribution_configuration(DISTRIBUTION_SM_TORCH_DIST_AND_DDP_DISABLED)
+    f = DummyFramework(
+        "some_script.py",
+        role="DummyRole",
+        instance_type="ml.p4.24xlarge",
+        sagemaker_session=sagemaker_session,
+        output_path="outputpath",
+        image_uri="some_acceptable_image"
+    )
+    # testing backwards compatibility with p4d instances
+    f._distribution_configuration(DISTRIBUTION_SM_TORCH_DIST_AND_DDP_ENABLED)
+    f._distribution_configuration(DISTRIBUTION_SM_TORCH_DIST_AND_DDP_DISABLED)
+
 
 def test_framework_all_init_args(sagemaker_session):
     f = DummyFramework(
