@@ -49,6 +49,11 @@ def sagemaker_session():
49
49
session_mock .sagemaker_client .describe_training_job = Mock (
50
50
return_value = {"ModelArtifacts" : {"S3ModelArtifacts" : "s3://m/m.tar.gz" }}
51
51
)
52
+ session_mock .boto_session .client ("ec2" ).describe_instance_types = Mock (
53
+ return_value = {
54
+ "InstanceTypes" : [{"CpuInfo" : {},},],
55
+ }
56
+ )
52
57
return session_mock
53
58
54
59
@@ -733,6 +738,31 @@ def test_validate_smdistributed_not_raises():
733
738
)
734
739
735
740
741
+ def test_validate_distribution_instance_no_smdistributed (sagemaker_session ):
742
+ distribution = {}
743
+ instance_type = "mock_type"
744
+ fw_utils .validate_distribution_instance (sagemaker_session , distribution , instance_type )
745
+
746
+
747
+ def test_validate_distribution_instance_no_modelparallel (sagemaker_session ):
748
+ distribution = {"smdistributed" : {}}
749
+ instance_type = "mock_type"
750
+ fw_utils .validate_distribution_instance (sagemaker_session , distribution , instance_type )
751
+
752
+
753
+ def test_validate_distribution_instance_disabled_modelparallel (sagemaker_session ):
754
+ distribution = {"smdistributed" : {"modelparallel" : {"enabled" : False }}}
755
+ instance_type = "mock_type"
756
+ fw_utils .validate_distribution_instance (sagemaker_session , distribution , instance_type )
757
+
758
+
759
+ def test_validate_distribution_instance_raise (sagemaker_session ):
760
+ distribution = {"smdistributed" : {"modelparallel" : {"enabled" : True }}}
761
+ instance_type = "mock_type"
762
+ with pytest .raises (ValueError ):
763
+ fw_utils .validate_distribution_instance (sagemaker_session , distribution , instance_type )
764
+
765
+
736
766
def test_validate_smdistributed_raises ():
737
767
bad_args = [
738
768
{"smdistributed" : "dummy" },
0 commit comments