Skip to content

Commit 20cd14b

Browse files
viclzhu authored and bhupendrasingh committed
change: Add 2.0-gpu-py310-cu121 to unsupported dlc images for the estimator.
Add additional unsupported image tests. Clean up tests.
1 parent 36bfae8 commit 20cd14b

File tree

2 files changed

+49
-24
lines changed

2 files changed

+49
-24
lines changed

src/sagemaker/estimator.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -3198,6 +3198,7 @@ class Framework(EstimatorBase):
31983198
"""
31993199

32003200
_framework_name = None
3201+
UNSUPPORTED_DLC_IMAGE_FOR_SM_PARALLELISM = ("2.0.1-gpu-py310-cu121", "2.0-gpu-py310-cu121")
32013202

32023203
def __init__(
32033204
self,
@@ -3415,7 +3416,6 @@ def __init__(
34153416
self.checkpoint_s3_uri = checkpoint_s3_uri
34163417
self.checkpoint_local_path = checkpoint_local_path
34173418
self.enable_sagemaker_metrics = enable_sagemaker_metrics
3418-
self.unsupported_dlc_image_for_sm_parallelism = ["2.0.1-gpu-py310-cu121"]
34193419

34203420
def _prepare_for_training(self, job_name=None):
34213421
"""Set hyperparameters needed for training. This method will also validate ``source_dir``.
@@ -3854,7 +3854,7 @@ def _distribution_configuration(self, distribution):
38543854
smdataparallel_enabled = smdistributed.get("dataparallel", {}).get("enabled", False)
38553855
p5_enabled = bool("p5.48xlarge" in self.instance_type)
38563856
img_uri = "" if self.image_uri is None else self.image_uri
3857-
for unsupported_image in self.unsupported_dlc_image_for_sm_parallelism:
3857+
for unsupported_image in Framework.UNSUPPORTED_DLC_IMAGE_FOR_SM_PARALLELISM:
38583858
if unsupported_image in img_uri and not torch_distributed_enabled: #disabling DLC images with CUDA12
38593859
raise ValueError(f"SMDistributed is currently incompatible with DLC image: {img_uri}. (Could be due to CUDA version being greater than 11.)")
38603860
if not torch_distributed_enabled and p5_enabled: #disabling p5 when torch distributed is disabled

tests/unit/test_estimator.py

Lines changed: 47 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -322,35 +322,57 @@ def training_job_description(sagemaker_session):
322322
sagemaker_session.describe_training_job = mock_describe_training_job
323323
return returned_job_description
324324

325+
326+
def test_validate_smdistributed_unsupported_image_raises(sagemaker_session):
327+
# Test unsupported image raises error.
328+
for unsupported_image in DummyFramework.UNSUPPORTED_DLC_IMAGE_FOR_SM_PARALLELISM:
329+
# Fail due to unsupported CUDA12 DLC image.
330+
f = DummyFramework(
331+
"some_script.py",
332+
role="DummyRole",
333+
instance_type="ml.p4d.24xlarge",
334+
sagemaker_session=sagemaker_session,
335+
output_path="outputpath",
336+
image_uri=unsupported_image,
337+
)
338+
with pytest.raises(ValueError):
339+
f._distribution_configuration(DISTRIBUTION_SM_DDP_ENABLED)
340+
with pytest.raises(ValueError):
341+
f._distribution_configuration(DISTRIBUTION_SM_DDP_DISABLED)
342+
343+
# Test unsupported image with suffix raises error.
344+
for unsupported_image in DummyFramework.UNSUPPORTED_DLC_IMAGE_FOR_SM_PARALLELISM:
345+
# Fail due to unsupported CUDA12 DLC image.
346+
f = DummyFramework(
347+
"some_script.py",
348+
role="DummyRole",
349+
instance_type="ml.p4d.24xlarge",
350+
sagemaker_session=sagemaker_session,
351+
output_path="outputpath",
352+
image_uri=unsupported_image + "-ubuntu20.04-sagemaker-pr-3303",
353+
)
354+
with pytest.raises(ValueError):
355+
f._distribution_configuration(DISTRIBUTION_SM_DDP_ENABLED)
356+
with pytest.raises(ValueError):
357+
f._distribution_configuration(DISTRIBUTION_SM_DDP_DISABLED)
358+
359+
325360
def test_validate_smdistributed_p5_raises(sagemaker_session):
326-
# supported DLC image
361+
# Supported DLC image.
327362
f = DummyFramework(
328363
"some_script.py",
329364
role="DummyRole",
330365
instance_type="ml.p5.48xlarge",
331366
sagemaker_session=sagemaker_session,
332367
output_path="outputpath",
333-
image_uri="some_acceptable_image"
368+
image_uri="some_acceptable_image",
334369
)
335-
#both fail because instance type is p5 and torch_distributed is off
370+
# Both fail because instance type is p5 and torch_distributed is off.
336371
with pytest.raises(ValueError):
337372
f._distribution_configuration(DISTRIBUTION_SM_DDP_ENABLED)
338373
with pytest.raises(ValueError):
339374
f._distribution_configuration(DISTRIBUTION_SM_DDP_DISABLED)
340-
# unsupported DLC image
341-
f = DummyFramework(
342-
"some_script.py",
343-
role="DummyRole",
344-
instance_type="ml.p5.48xlarge",
345-
sagemaker_session=sagemaker_session,
346-
output_path="outputpath",
347-
image_uri="ecr-url/2.0.1-gpu-py310-cu121-ubuntu20.04-sagemaker-pr-3303"
348-
)
349-
#both fail due to unsupported CUDA12 DLC image
350-
with pytest.raises(ValueError):
351-
f._distribution_configuration(DISTRIBUTION_SM_DDP_ENABLED)
352-
with pytest.raises(ValueError):
353-
f._distribution_configuration(DISTRIBUTION_SM_DDP_DISABLED)
375+
354376

355377
def test_validate_smdistributed_p5_not_raises(sagemaker_session):
356378
f = DummyFramework(
@@ -359,20 +381,23 @@ def test_validate_smdistributed_p5_not_raises(sagemaker_session):
359381
instance_type="ml.p5.48xlarge",
360382
sagemaker_session=sagemaker_session,
361383
output_path="outputpath",
362-
image_uri="ecr-url/2.0.1-gpu-py310-cu121-ubuntu20.04-sagemaker-pr-3303"
384+
image_uri="ecr-url/2.0.1-gpu-py310-cu121-ubuntu20.04-sagemaker-pr-3303",
363385
)
364-
#testing with p5 instance and torch_distributed enabled
386+
# Testing with p5 instance and torch_distributed enabled.
365387
f._distribution_configuration(DISTRIBUTION_SM_TORCH_DIST_AND_DDP_ENABLED)
366388
f._distribution_configuration(DISTRIBUTION_SM_TORCH_DIST_AND_DDP_DISABLED)
389+
390+
391+
def test_validate_smdistributed_backward_compat_p4_not_raises(sagemaker_session):
367392
f = DummyFramework(
368393
"some_script.py",
369394
role="DummyRole",
370-
instance_type="ml.p4.24xlarge",
395+
instance_type="ml.p4d.24xlarge",
371396
sagemaker_session=sagemaker_session,
372397
output_path="outputpath",
373-
image_uri="some_acceptable_image"
398+
image_uri="some_acceptable_image",
374399
)
375-
#testing backwards compatability with p4d instances
400+
# Testing backwards compatibility with p4d instances.
376401
f._distribution_configuration(DISTRIBUTION_SM_TORCH_DIST_AND_DDP_ENABLED)
377402
f._distribution_configuration(DISTRIBUTION_SM_TORCH_DIST_AND_DDP_DISABLED)
378403

0 commit comments

Comments
 (0)