@@ -402,6 +402,39 @@ def test_validate_smdistributed_backward_compat_p4_not_raises(sagemaker_session)
402
402
f ._distribution_configuration (DISTRIBUTION_SM_TORCH_DIST_AND_DDP_DISABLED )
403
403
404
404
405
+ def test_validate_smdistributed_instance_groups_raises (sagemaker_session ):
406
+ instance_group_1 = InstanceGroup ("train_group" , "ml.p4d.24xlarge" , 2 )
407
+ instance_group_2 = InstanceGroup ("train_group" , "ml.p5.48xlarge" , 2 )
408
+ f = DummyFramework (
409
+ "some_script.py" ,
410
+ role = "DummyRole" ,
411
+ instance_groups = [instance_group_1 , instance_group_2 ],
412
+ sagemaker_session = sagemaker_session ,
413
+ output_path = "outputpath" ,
414
+ image_uri = "some_acceptable_image" ,
415
+ )
416
+ # Testing instance_group with p5 raises exception
417
+ with pytest .raises (ValueError ):
418
+ f ._distribution_configuration (DISTRIBUTION_SM_DDP_ENABLED )
419
+ with pytest .raises (ValueError ):
420
+ f ._distribution_configuration (DISTRIBUTION_SM_DDP_DISABLED )
421
+
422
+
423
+ def test_validate_smdistributed_instance_groups_not_raises (sagemaker_session ):
424
+ instance_group_1 = InstanceGroup ("train_group" , "ml.p4d.24xlarge" , 2 )
425
+ f = DummyFramework (
426
+ "some_script.py" ,
427
+ role = "DummyRole" ,
428
+ instance_groups = [instance_group_1 ],
429
+ sagemaker_session = sagemaker_session ,
430
+ output_path = "outputpath" ,
431
+ image_uri = "some_acceptable_image" ,
432
+ )
433
+ # Testing instance_group without p5 does not raise exception
434
+ f ._distribution_configuration (DISTRIBUTION_SM_TORCH_DIST_AND_DDP_ENABLED )
435
+ f ._distribution_configuration (DISTRIBUTION_SM_TORCH_DIST_AND_DDP_DISABLED )
436
+
437
+
405
438
def test_framework_all_init_args (sagemaker_session ):
406
439
f = DummyFramework (
407
440
"my_script.py" ,
0 commit comments