Skip to content

Commit 1506147

Browse files
committed
add unittests
1 parent b600dd1 commit 1506147

File tree

4 files changed

+171
-8
lines changed

4 files changed

+171
-8
lines changed

src/sagemaker/jumpstart/cache.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -368,7 +368,7 @@ def _retrieval_function(
368368
hub_description: DescribeHubResponse = hub.describe()
369369
return JumpStartCachedContentValue(formatted_content=hub_description)
370370
raise ValueError(
371-
f"Bad value for key '{key}': must be in",
371+
f"Bad value for key '{key}': must be in ",
372372
f"{[JumpStartS3FileType.MANIFEST, JumpStartS3FileType.SPECS, HubContentType.HUB, HubContentType.MODEL]}"
373373
)
374374

src/sagemaker/jumpstart/curated_hub/curated_hub.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,10 +34,20 @@ def __init__(
3434
region: str,
3535
sagemaker_session: Optional[Session] = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
3636
):
37+
"""Instantiates a SageMaker ``CuratedHub``.
38+
39+
Args:
40+
hub_name (str): The name of the Hub to create.
41+
region (str): The region in which the CuratedHub is in.
42+
sagemaker_session (sagemaker.session.Session): A SageMaker Session
43+
object, used for SageMaker interactions.
44+
"""
3745
self.hub_name = hub_name
3846
if sagemaker_session.boto_region_name != region:
39-
# TODO: Handle error
40-
pass
47+
raise ValueError(
48+
f"Cannot have conflicting regions for region=[{region}] and ",
49+
f"sagemaker_session region=[{str(sagemaker_session.boto_region_name)}].",
50+
)
4151
self.region = region
4252
self._sagemaker_session = sagemaker_session
4353

Lines changed: 158 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,158 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
import pytest
15+
from mock import Mock
16+
from sagemaker.jumpstart.curated_hub.curated_hub import CuratedHub
17+
18+
REGION = "us-east-1"
19+
ACCOUNT_ID = "123456789123"
20+
HUB_NAME = "mock-hub-name"
21+
22+
23+
@pytest.fixture()
24+
def sagemaker_session():
25+
boto_mock = Mock(name="boto_session")
26+
sagemaker_session_mock = Mock(
27+
name="sagemaker_session", boto_session=boto_mock, boto_region_name=REGION
28+
)
29+
sagemaker_session_mock._client_config.user_agent = (
30+
"Boto3/1.9.69 Python/3.6.5 Linux/4.14.77-70.82.amzn1.x86_64 Botocore/1.12.69 Resource"
31+
)
32+
sagemaker_session_mock.account_id.return_value = ACCOUNT_ID
33+
return sagemaker_session_mock
34+
35+
36+
# @pytest.fixture()
37+
# def sagemaker_session():
38+
# boto_mock = Mock(name="boto_session", region_name=REGION)
39+
# session_mock = Mock(
40+
# name="sagemaker_session",
41+
# boto_session=boto_mock,
42+
# boto_region_name=REGION,
43+
# config=None,
44+
# local_mode=False,
45+
# default_bucket_prefix=None,
46+
# )
47+
# session_mock.return_value.sagemkaer_client = Mock(name="sagemaker_client")
48+
# session_mock.sts_client.get_caller_identity = Mock(return_value={"Account": ACCOUNT_ID})
49+
# create_hub = {"HubArn": "arn:aws:sagemaker:us-east-1:123456789123:hub/mock-hub-name"}
50+
# session_mock.sagemaker_client.create_hub = Mock(return_value=create_hub)
51+
# print(session_mock.sagemaker_client)
52+
# return session_mock
53+
54+
55+
def test_instantiates(sagemaker_session):
56+
hub = CuratedHub(hub_name=HUB_NAME, region=REGION, sagemaker_session=sagemaker_session)
57+
assert hub.hub_name == HUB_NAME
58+
assert hub.region == "us-east-1"
59+
assert hub._sagemaker_session == sagemaker_session
60+
61+
62+
def test_instantiates_handles_conflicting_regions(sagemaker_session):
63+
conflicting_region = "us-east-2"
64+
65+
with pytest.raises(ValueError):
66+
CuratedHub(
67+
hub_name=HUB_NAME, region=conflicting_region, sagemaker_session=sagemaker_session
68+
)
69+
70+
71+
@pytest.mark.parametrize(
72+
("hub_name,hub_description,hub_bucket_name,hub_display_name,hub_search_keywords,tags"),
73+
[
74+
pytest.param("MockHub1", "this is my sagemaker hub", None, None, None, None),
75+
pytest.param(
76+
"MockHub2",
77+
"this is my sagemaker hub two",
78+
None,
79+
"DisplayMockHub2",
80+
["mock", "hub", "123"],
81+
[{"Key": "tag-key-1", "Value": "tag-value-1"}],
82+
),
83+
],
84+
)
85+
def test_create_with_no_bucket_name(
86+
sagemaker_session,
87+
hub_name,
88+
hub_description,
89+
hub_bucket_name,
90+
hub_display_name,
91+
hub_search_keywords,
92+
tags,
93+
):
94+
create_hub = {"HubArn": f"arn:aws:sagemaker:us-east-1:123456789123:hub/{hub_name}"}
95+
sagemaker_session.create_hub = Mock(return_value=create_hub)
96+
hub = CuratedHub(hub_name=hub_name, region=REGION, sagemaker_session=sagemaker_session)
97+
request = {
98+
"hub_name": hub_name,
99+
"hub_description": hub_description,
100+
"hub_bucket_name": "sagemaker-hubs-us-east-1-123456789123",
101+
"hub_display_name": hub_display_name,
102+
"hub_search_keywords": hub_search_keywords,
103+
"tags": tags,
104+
}
105+
response = hub.create(
106+
description=hub_description,
107+
display_name=hub_display_name,
108+
bucket_name=hub_bucket_name,
109+
search_keywords=hub_search_keywords,
110+
tags=tags,
111+
)
112+
sagemaker_session.create_hub.assert_called_with(**request)
113+
assert response == {"HubArn": f"arn:aws:sagemaker:us-east-1:123456789123:hub/{hub_name}"}
114+
115+
116+
@pytest.mark.parametrize(
117+
("hub_name,hub_description,hub_bucket_name,hub_display_name,hub_search_keywords,tags"),
118+
[
119+
pytest.param("MockHub1", "this is my sagemaker hub", "mock-bucket-123", None, None, None),
120+
pytest.param(
121+
"MockHub2",
122+
"this is my sagemaker hub two",
123+
"mock-bucket-123",
124+
"DisplayMockHub2",
125+
["mock", "hub", "123"],
126+
[{"Key": "tag-key-1", "Value": "tag-value-1"}],
127+
),
128+
],
129+
)
130+
def test_create_with_bucket_name(
131+
sagemaker_session,
132+
hub_name,
133+
hub_description,
134+
hub_bucket_name,
135+
hub_display_name,
136+
hub_search_keywords,
137+
tags,
138+
):
139+
create_hub = {"HubArn": f"arn:aws:sagemaker:us-east-1:123456789123:hub/{hub_name}"}
140+
sagemaker_session.create_hub = Mock(return_value=create_hub)
141+
hub = CuratedHub(hub_name=hub_name, region=REGION, sagemaker_session=sagemaker_session)
142+
request = {
143+
"hub_name": hub_name,
144+
"hub_description": hub_description,
145+
"hub_bucket_name": hub_bucket_name,
146+
"hub_display_name": hub_display_name,
147+
"hub_search_keywords": hub_search_keywords,
148+
"tags": tags,
149+
}
150+
response = hub.create(
151+
description=hub_description,
152+
display_name=hub_display_name,
153+
bucket_name=hub_bucket_name,
154+
search_keywords=hub_search_keywords,
155+
tags=tags,
156+
)
157+
sagemaker_session.create_hub.assert_called_with(**request)
158+
assert response == {"HubArn": f"arn:aws:sagemaker:us-east-1:123456789123:hub/{hub_name}"}

tests/unit/sagemaker/jumpstart/curated_hub/test_utils.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
from __future__ import absolute_import
1414

1515
from unittest.mock import Mock
16-
from botocore.exceptions import ClientError
1716
from sagemaker.jumpstart.curated_hub import utils
1817
from sagemaker.jumpstart.constants import JUMPSTART_DEFAULT_REGION_NAME
1918
from sagemaker.jumpstart.curated_hub.types import HubArnExtractedInfo
@@ -173,10 +172,6 @@ def test_create_hub_bucket_if_it_does_not_exist():
173172
}
174173
mock_sagemaker_session.boto_session.resource("s3").Bucket().creation_date = None
175174
mock_sagemaker_session.boto_region_name = "us-east-1"
176-
error = ClientError(
177-
error_response={"Error": {"Code": "404", "Message": "Not Found"}},
178-
operation_name="foo",
179-
)
180175
bucket_name = "sagemaker-hubs-us-east-1-123456789123"
181176
created_hub_bucket_name = utils.create_hub_bucket_if_it_does_not_exist(
182177
sagemaker_session=mock_sagemaker_session

0 commit comments

Comments
 (0)