@@ -322,35 +322,57 @@ def training_job_description(sagemaker_session):
322
322
sagemaker_session .describe_training_job = mock_describe_training_job
323
323
return returned_job_description
324
324
325
+
326
def test_validate_smdistributed_unsupported_image_raises(sagemaker_session):
    """Verify that unsupported DLC images are rejected, with and without a suffix.

    Every URI in ``DummyFramework.UNSUPPORTED_DLC_IMAGE_FOR_SM_PARALLELISM`` is
    tried verbatim first, and then again with a trailing suffix appended. In both
    passes ``_distribution_configuration`` must raise ``ValueError`` for
    ``DISTRIBUTION_SM_DDP_ENABLED`` as well as ``DISTRIBUTION_SM_DDP_DISABLED``.
    """
    # The empty suffix covers the bare image URI; the second entry covers the
    # "unsupported image with suffix" case.
    image_suffixes = ("", "-ubuntu20.04-sagemaker-pr-3303")
    distributions = (DISTRIBUTION_SM_DDP_ENABLED, DISTRIBUTION_SM_DDP_DISABLED)

    for suffix in image_suffixes:
        for base_image in DummyFramework.UNSUPPORTED_DLC_IMAGE_FOR_SM_PARALLELISM:
            estimator = DummyFramework(
                "some_script.py",
                role="DummyRole",
                instance_type="ml.p4d.24xlarge",
                sagemaker_session=sagemaker_session,
                output_path="outputpath",
                image_uri=base_image + suffix,
            )
            # Fail due to unsupported CUDA12 DLC image.
            for distribution in distributions:
                with pytest.raises(ValueError):
                    estimator._distribution_configuration(distribution)
358
+
359
+
325
360
def test_validate_smdistributed_p5_raises(sagemaker_session):
    """Verify that a p5 instance is rejected when torch_distributed is off.

    The estimator is built with an acceptable image, so the image is not the
    reason for the failure. ``_distribution_configuration`` must raise
    ``ValueError`` for both ``DISTRIBUTION_SM_DDP_ENABLED`` and
    ``DISTRIBUTION_SM_DDP_DISABLED``.
    """
    # Supported DLC image on a p5 instance.
    estimator = DummyFramework(
        "some_script.py",
        role="DummyRole",
        instance_type="ml.p5.48xlarge",
        sagemaker_session=sagemaker_session,
        output_path="outputpath",
        image_uri="some_acceptable_image",
    )

    # Both fail because instance type is p5 and torch_distributed is off.
    for distribution in (DISTRIBUTION_SM_DDP_ENABLED, DISTRIBUTION_SM_DDP_DISABLED):
        with pytest.raises(ValueError):
            estimator._distribution_configuration(distribution)
375
+
354
376
355
377
def test_validate_smdistributed_p5_not_raises (sagemaker_session ):
356
378
f = DummyFramework (
@@ -359,20 +381,23 @@ def test_validate_smdistributed_p5_not_raises(sagemaker_session):
359
381
instance_type = "ml.p5.48xlarge" ,
360
382
sagemaker_session = sagemaker_session ,
361
383
output_path = "outputpath" ,
362
- image_uri = "ecr-url/2.0.1-gpu-py310-cu121-ubuntu20.04-sagemaker-pr-3303"
384
+ image_uri = "ecr-url/2.0.1-gpu-py310-cu121-ubuntu20.04-sagemaker-pr-3303" ,
363
385
)
364
- #testing with p5 instance and torch_distributed enabled
386
+ # Testing with p5 instance and torch_distributed enabled.
365
387
f ._distribution_configuration (DISTRIBUTION_SM_TORCH_DIST_AND_DDP_ENABLED )
366
388
f ._distribution_configuration (DISTRIBUTION_SM_TORCH_DIST_AND_DDP_DISABLED )
389
+
390
+
391
def test_validate_smdistributed_backward_compat_p4_not_raises(sagemaker_session):
    """Backward-compatibility check: p4d instances must keep working.

    Builds an estimator on a p4d instance with an acceptable image and calls
    ``_distribution_configuration`` with
    ``DISTRIBUTION_SM_TORCH_DIST_AND_DDP_ENABLED`` and
    ``DISTRIBUTION_SM_TORCH_DIST_AND_DDP_DISABLED``. The test passes as long as
    neither call raises.
    """
    estimator = DummyFramework(
        "some_script.py",
        role="DummyRole",
        instance_type="ml.p4d.24xlarge",
        sagemaker_session=sagemaker_session,
        output_path="outputpath",
        image_uri="some_acceptable_image",
    )

    # No assertion is needed: an exception from either call fails the test.
    for distribution in (
        DISTRIBUTION_SM_TORCH_DIST_AND_DDP_ENABLED,
        DISTRIBUTION_SM_TORCH_DIST_AND_DDP_DISABLED,
    ):
        estimator._distribution_configuration(distribution)
378
403
0 commit comments