File tree Expand file tree Collapse file tree 1 file changed +25
-0
lines changed Expand file tree Collapse file tree 1 file changed +25
-0
lines changed Original file line number Diff line number Diff line change @@ -41,6 +41,13 @@ def cd(path):
41
41
@pytest .fixture ()
42
42
def sagemaker_session ():
43
43
boto_mock = Mock (name = "boto_session" , region_name = "us-west-2" )
44
+ boto_mock .client .describe_instance_types = Mock (
45
+ return_value = {
46
+ "InstanceTypes" : [
47
+ {"GpuInfo" : {}}
48
+ ]
49
+ }
50
+ )
44
51
session_mock = Mock (
45
52
name = "sagemaker_session" , boto_session = boto_mock , s3_client = None , s3_resource = None
46
53
)
@@ -733,6 +740,24 @@ def test_validate_smdistributed_not_raises():
733
740
)
734
741
735
742
743
+ def test_validate_distribution_instance_no_smdistributed (sagemaker_session ):
744
+ distribution = {}
745
+ instance_type = "mock_type"
746
+ fw_utils .validate_distribution_instance (sagemaker_session , distribution , instance_type )
747
+
748
+
749
+ def test_validate_distribution_instance_no_modelparallel (sagemaker_session ):
750
+ distribution = {"smdistributed" : {}}
751
+ instance_type = "mock_type"
752
+ fw_utils .validate_distribution_instance (sagemaker_session , distribution , instance_type )
753
+
754
+
755
+ def test_validate_distribution_instance_disabled_modelparallel (sagemaker_session ):
756
+ distribution = {"smdistributed" : {"modelparallel" : {"enabled" : False }}}
757
+ instance_type = "mock_type"
758
+ fw_utils .validate_distribution_instance (sagemaker_session , distribution , instance_type )
759
+
760
+
736
761
def test_validate_smdistributed_raises ():
737
762
bad_args = [
738
763
{"smdistributed" : "dummy" },
You can’t perform that action at this time.
0 commit comments