Skip to content

Commit 0947840

Browse files
committed
fix: Adding unit tests
1 parent 48d6325 commit 0947840

File tree

3 files changed

+124
-8
lines changed

3 files changed

+124
-8
lines changed

src/sagemaker/jumpstart/curated_hub/types.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ class CuratedHubTagName(str, Enum):
2727

2828
@dataclass
2929
class Tag:
30-
key: str
30+
key: CuratedHubTagName
3131
value: str
3232

3333
@dataclass

src/sagemaker/jumpstart/curated_hub/utils.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -196,15 +196,18 @@ def tag_hub_content(hub_content_arn: str, tags: List[Tag], session: Session) ->
196196
return responses
197197

198198
def find_all_tags_for_jumpstart_model(hub_name: str, hub_content_name: str, region: str, session: Session) -> List[Tag]:
199-
hub_content_versions = session.list_hub_content_versions(
199+
list_versions_response = session.list_hub_content_versions(
200200
hub_name=hub_name,
201201
hub_content_type='Model',
202202
hub_content_name=hub_content_name
203203
)
204+
hub_content_versions = list_versions_response["HubContentSummaries"]
204205

205206
tag_name_to_versions_map: Dict[CuratedHubTagName, List[str]] = {}
206207
for hub_content_version_summary in hub_content_versions:
207208
jumpstart_model = get_jumpstart_model_and_version(hub_content_version_summary)
209+
if jumpstart_model["model_id"] is None or jumpstart_model["version"] is None:
210+
continue
208211
tag_names_to_add: List[CuratedHubTagName] = find_tags_for_jumpstart_model_version(
209212
model_id=jumpstart_model["model_id"],
210213
version=jumpstart_model["version"],
@@ -215,13 +218,13 @@ def find_all_tags_for_jumpstart_model(hub_name: str, hub_content_name: str, regi
215218
for tag_name in tag_names_to_add:
216219
if tag_name not in tag_name_to_versions_map:
217220
tag_name_to_versions_map[tag_name] = []
218-
tag_name_to_versions_map[tag_name].append(jumpstart_model["version"])
221+
tag_name_to_versions_map[tag_name].append(hub_content_version_summary["HubContentVersion"])
219222

220223
tags: List[Tag] = []
221-
for tag_name, versions in tag_name_to_versions_map:
224+
for tag_name, versions in tag_name_to_versions_map.items():
222225
tags.append(Tag(
223226
key=tag_name,
224-
versions=str(versions)
227+
value=str(versions)
225228
))
226229

227230
return tags
@@ -241,11 +244,11 @@ def find_tags_for_jumpstart_model_version(model_id: str, version: str, region: s
241244
)
242245

243246
if (specs.deprecated):
244-
tags_to_add.add(CuratedHubTagName.DEPRECATED_VERSIONS_TAG)
247+
tags_to_add.append(CuratedHubTagName.DEPRECATED_VERSIONS_TAG)
245248
if (specs.inference_vulnerable):
246-
tags_to_add.add(CuratedHubTagName.INFERENCE_VULNERABLE_VERSIONS_TAG)
249+
tags_to_add.append(CuratedHubTagName.INFERENCE_VULNERABLE_VERSIONS_TAG)
247250
if (specs.training_vulnerable):
248-
tags_to_add.add(CuratedHubTagName.TRAINING_VULNERABLE_VERSIONS_TAG)
251+
tags_to_add.append(CuratedHubTagName.TRAINING_VULNERABLE_VERSIONS_TAG)
249252

250253
return tags_to_add
251254

tests/unit/sagemaker/jumpstart/curated_hub/test_utils.py

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,13 @@
1515
from unittest.mock import Mock
1616
from sagemaker.jumpstart.types import HubArnExtractedInfo
1717
from sagemaker.jumpstart.constants import JUMPSTART_DEFAULT_REGION_NAME
18+
from sagemaker.jumpstart.enums import JumpStartScriptScope
1819
from sagemaker.jumpstart.curated_hub import utils
20+
from unittest.mock import patch
21+
from sagemaker.jumpstart.curated_hub.types import (
22+
Tag,
23+
CuratedHubTagName
24+
)
1925

2026

2127
def test_get_info_from_hub_resource_arn():
@@ -168,3 +174,110 @@ def test_create_hub_bucket_if_it_does_not_exist():
168174

169175
mock_sagemaker_session.boto_session.resource("s3").create_bucketassert_called_once()
170176
assert created_hub_bucket_name == bucket_name
177+
178+
@patch("sagemaker.jumpstart.utils.verify_model_region_and_return_specs")
179+
def test_find_tags_for_jumpstart_model_version(mock_spec_util):
180+
mock_sagemaker_session = Mock()
181+
mock_specs = Mock()
182+
mock_specs.deprecated = True
183+
mock_specs.inference_vulnerable = True
184+
mock_specs.training_vulnerable = True
185+
mock_spec_util.return_value = mock_specs
186+
187+
tags = utils.find_tags_for_jumpstart_model_version(
188+
model_id="test",
189+
version="test",
190+
region="test",
191+
session=mock_sagemaker_session
192+
)
193+
194+
mock_spec_util.assert_called_once_with(
195+
model_id="test",
196+
version="test",
197+
region="test",
198+
scope=JumpStartScriptScope.INFERENCE,
199+
tolerate_vulnerable_model = True,
200+
tolerate_deprecated_model = True,
201+
sagemaker_session=mock_sagemaker_session,
202+
)
203+
204+
assert tags == [CuratedHubTagName.DEPRECATED_VERSIONS_TAG, CuratedHubTagName.INFERENCE_VULNERABLE_VERSIONS_TAG, CuratedHubTagName.TRAINING_VULNERABLE_VERSIONS_TAG]
205+
206+
@patch("sagemaker.jumpstart.utils.verify_model_region_and_return_specs")
207+
def test_find_tags_for_jumpstart_model_version_some_false(mock_spec_util):
208+
mock_sagemaker_session = Mock()
209+
mock_specs = Mock()
210+
mock_specs.deprecated = True
211+
mock_specs.inference_vulnerable = False
212+
mock_specs.training_vulnerable = False
213+
mock_spec_util.return_value = mock_specs
214+
215+
tags = utils.find_tags_for_jumpstart_model_version(
216+
model_id="test",
217+
version="test",
218+
region="test",
219+
session=mock_sagemaker_session
220+
)
221+
222+
mock_spec_util.assert_called_once_with(
223+
model_id="test",
224+
version="test",
225+
region="test",
226+
scope=JumpStartScriptScope.INFERENCE,
227+
tolerate_vulnerable_model = True,
228+
tolerate_deprecated_model = True,
229+
sagemaker_session=mock_sagemaker_session,
230+
)
231+
232+
assert tags == [CuratedHubTagName.DEPRECATED_VERSIONS_TAG]
233+
234+
@patch("sagemaker.jumpstart.utils.verify_model_region_and_return_specs")
235+
def test_find_all_tags_for_jumpstart_model(mock_spec_util):
236+
mock_sagemaker_session = Mock()
237+
mock_sagemaker_session.list_hub_content_versions.return_value = {
238+
"HubContentSummaries": [
239+
{
240+
"HubContentVersion": "1.0.0",
241+
"search_keywords": [
242+
"@jumpstart-model-id:model-one-pytorch",
243+
"@jumpstart-model-version:1.0.3",
244+
]
245+
},
246+
{
247+
"HubContentVersion": "2.0.0",
248+
"search_keywords": [
249+
"@jumpstart-model-id:model-four-huggingface",
250+
"@jumpstart-model-version:2.0.2",
251+
]
252+
},
253+
{
254+
"HubContentVersion": "3.0.0",
255+
"search_keywords": []
256+
}
257+
]
258+
}
259+
260+
mock_specs = Mock()
261+
mock_specs.deprecated = True
262+
mock_specs.inference_vulnerable = True
263+
mock_specs.training_vulnerable = True
264+
mock_spec_util.return_value = mock_specs
265+
266+
tags = utils.find_all_tags_for_jumpstart_model(
267+
hub_name="test",
268+
hub_content_name="test",
269+
region="test",
270+
session=mock_sagemaker_session
271+
)
272+
273+
mock_sagemaker_session.list_hub_content_versions.assert_called_once_with(
274+
hub_name="test",
275+
hub_content_type='Model',
276+
hub_content_name="test",
277+
)
278+
279+
assert tags == [
280+
Tag(key=CuratedHubTagName.DEPRECATED_VERSIONS_TAG, value=str(["1.0.0", "2.0.0"])),
281+
Tag(key=CuratedHubTagName.INFERENCE_VULNERABLE_VERSIONS_TAG, value=str(["1.0.0", "2.0.0"])),
282+
Tag(key=CuratedHubTagName.TRAINING_VULNERABLE_VERSIONS_TAG, value=str(["1.0.0", "2.0.0"]))
283+
]

0 commit comments

Comments
 (0)