Skip to content

Commit be9f546

Browse files
author
Yongyan Rao
committed
Add unit tests
1 parent 5207ddc commit be9f546

File tree

1 file changed

+25
-0
lines changed

1 file changed

+25
-0
lines changed

tests/unit/test_fw_utils.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,13 @@ def cd(path):
4141
@pytest.fixture()
4242
def sagemaker_session():
4343
boto_mock = Mock(name="boto_session", region_name="us-west-2")
44+
boto_mock.client.describe_instance_types = Mock(
45+
return_value={
46+
"InstanceTypes": [
47+
{"GpuInfo": {}}
48+
]
49+
}
50+
)
4451
session_mock = Mock(
4552
name="sagemaker_session", boto_session=boto_mock, s3_client=None, s3_resource=None
4653
)
@@ -733,6 +740,24 @@ def test_validate_smdistributed_not_raises():
733740
)
734741

735742

743+
def test_validate_distribution_instance_no_smdistributed(sagemaker_session):
744+
distribution = {}
745+
instance_type = "mock_type"
746+
fw_utils.validate_distribution_instance(sagemaker_session, distribution, instance_type)
747+
748+
749+
def test_validate_distribution_instance_no_modelparallel(sagemaker_session):
750+
distribution = {"smdistributed": {}}
751+
instance_type = "mock_type"
752+
fw_utils.validate_distribution_instance(sagemaker_session, distribution, instance_type)
753+
754+
755+
def test_validate_distribution_instance_disabled_modelparallel(sagemaker_session):
756+
distribution = {"smdistributed": {"modelparallel": {"enabled": False}}}
757+
instance_type = "mock_type"
758+
fw_utils.validate_distribution_instance(sagemaker_session, distribution, instance_type)
759+
760+
736761
def test_validate_smdistributed_raises():
737762
bad_args = [
738763
{"smdistributed": "dummy"},

0 commit comments

Comments
 (0)