Skip to content

Commit 9e703a9

Browse files
committed
fix: Adding more unittests
1 parent 453bed4 commit 9e703a9

File tree

4 files changed

+99
-71
lines changed

4 files changed

+99
-71
lines changed

src/sagemaker/jumpstart/curated_hub/constants.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
"""This module stores constants related to SageMaker JumpStart CuratedHub."""
1414
from __future__ import absolute_import
1515

16-
JUMPSTART_HUB_MODEL_ID_TAG_PREFIX = "@jumpstart-model-id"
17-
JUMPSTART_HUB_MODEL_VERSION_TAG_PREFIX = "@jumpstart-model-version"
18-
FRAMEWORK_TAG_PREFIX = "@framework"
19-
TASK_TAG_PREFIX = "@mltask"
16+
JUMPSTART_HUB_MODEL_ID_TAG_PREFIX = "@jumpstart-model-id:"
17+
JUMPSTART_HUB_MODEL_VERSION_TAG_PREFIX = "@jumpstart-model-version:"
18+
FRAMEWORK_TAG_PREFIX = "@framework:"
19+
TASK_TAG_PREFIX = "@mltask:"

src/sagemaker/jumpstart/curated_hub/curated_hub.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -281,7 +281,7 @@ def sync(self, model_list: List[Dict[str, str]]):
281281
jumpstart_models_in_hub = self._get_jumpstart_models_in_hub()
282282
curated_models = [
283283
CuratedHubModelInfo(
284-
jumpstart_model_id=get_jumpstart_model_and_version(model),
284+
jumpstart_model_info=get_jumpstart_model_and_version(model),
285285
hub_content_model_id=model.hub_content_name,
286286
hub_content_version=model.hub_content_version
287287
) for model in jumpstart_models_in_hub

src/sagemaker/jumpstart/curated_hub/utils.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -252,25 +252,27 @@ def find_jumpstart_tags_for_model_version(model_id: str, version: str, region: s
252252

253253

254254
def get_jumpstart_model_and_version(hub_content_summary: HubContentSummary) -> Optional[JumpStartModelInfo]:
255-
jumpstart_model_id = next(
255+
jumpstart_model_id_tag = next(
256256
(
257-
tag[len(JUMPSTART_HUB_MODEL_ID_TAG_PREFIX)+1:] # Need to remove the tag_prefix and ":"
257+
tag
258258
for tag in hub_content_summary.hub_content_search_keywords
259259
if tag.startswith(JUMPSTART_HUB_MODEL_ID_TAG_PREFIX)
260260
),
261261
None,
262262
)
263-
jumpstart_model_version = next(
263+
jumpstart_model_version_tag = next(
264264
(
265-
tag[len(JUMPSTART_HUB_MODEL_VERSION_TAG_PREFIX)+1:]
265+
tag
266266
for tag in hub_content_summary.hub_content_search_keywords
267267
if tag.startswith(JUMPSTART_HUB_MODEL_VERSION_TAG_PREFIX)
268268
),
269269
None,
270270
)
271271

272-
if jumpstart_model_id is None or jumpstart_model_version is None:
272+
if jumpstart_model_id_tag is None or jumpstart_model_version_tag is None:
273273
return None
274+
jumpstart_model_id = jumpstart_model_id_tag[len(JUMPSTART_HUB_MODEL_ID_TAG_PREFIX):] # The prefix constant already ends with ":", so no extra +1 offset is needed
275+
jumpstart_model_version = jumpstart_model_version_tag[len(JUMPSTART_HUB_MODEL_VERSION_TAG_PREFIX):]
274276
return JumpStartModelInfo(model_id=jumpstart_model_id, version=jumpstart_model_version)
275277

276278
def get_latest_version_for_model(model_id: str, region: str) -> str:

tests/unit/sagemaker/jumpstart/curated_hub/test_curated_hub.py

Lines changed: 87 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,48 @@
3232

3333
FAKE_TIME = datetime.datetime(1997, 8, 14, 00, 00, 00)
3434

35+
MOCK_GENERATED_MODEL_ID = "mock_generated_model_id"
36+
MOCK_LIST_RESPONSE = {
37+
"HubContentSummaries": [
38+
{
39+
"HubContentName": "mock-hub-content-single-version",
40+
"HubContentVersion": "1.0.0",
41+
"HubContentSearchKeywords": [
42+
"@jumpstart-model-id:test-jumpstart-model-exists",
43+
"@jumpstart-model-version:2.0.0",
44+
],
45+
},
46+
{
47+
"HubContentName": "mock-hub-content-hub-version-greater",
48+
"HubContentVersion": "3.0.0",
49+
"HubContentSearchKeywords": [
50+
"@jumpstart-model-id:test-jumpstart-model-hub-version-greater",
51+
"@jumpstart-model-version:2.0.0",
52+
],
53+
},
54+
{"HubContentName": "test-model-no-keywords", "HubContentVersion": "1.0.0", "HubContentSearchKeywords": []},
55+
{"HubContentName": "test-model-no-keywords", "HubContentVersion": "1.0.1", "HubContentSearchKeywords": []},
56+
{"HubContentName": "test-model-no-jumpstart-keywords", "HubContentVersion": "1.0.2", "HubContentSearchKeywords": ["tag", "tag2"]},
57+
{"HubContentName": "test-model-missing-jumpstart-keywords", "HubContentVersion": "1.0.2", "HubContentSearchKeywords": ["@jumpstart-model-id:model-one-pytorch", "tag2"]},
58+
{
59+
"HubContentName": "mock-hub-content-multiple-versions",
60+
"HubContentVersion": "1.0.0",
61+
"HubContentSearchKeywords": [
62+
"@jumpstart-model-id:test-jumpstart-model-multiple-versions",
63+
"@jumpstart-model-version:2.0.0",
64+
],
65+
},
66+
{
67+
"HubContentName": "mock-hub-content-multiple-versions",
68+
"HubContentVersion": "1.0.1",
69+
"HubContentSearchKeywords": [
70+
"@jumpstart-model-id:test-jumpstart-model-multiple-versions",
71+
"@jumpstart-model-version:2.0.1",
72+
],
73+
},
74+
]
75+
}
76+
3577

3678
@pytest.fixture()
3779
def sagemaker_session():
@@ -200,40 +242,20 @@ def test_sync_filters_models_that_exist_in_hub(
200242
mock_list_models, mock_sync_public_models, mock_get_model_specs, mock_model_id_generation, sagemaker_session
201243
):
202244
mock_get_model_specs.side_effect = get_spec_from_base_spec
203-
mock_list_models.return_value = {
204-
"HubContentSummaries": [
205-
{
206-
"HubContentName": "mock-model-two-pytorch",
207-
"HubContentVersion": "1.0.2",
208-
"HubContentSearchKeywords": [
209-
"@jumpstart-model-id:mock-pytorch-model-already-exists-in-hub",
210-
"@jumpstart-model-version:1.0.2",
211-
],
212-
},
213-
{"HubContentName": "mock-model-three-nonsense", "HubContentVersion": "1.0.2", "HubContentSearchKeywords": []},
214-
{
215-
"HubContentName": "mock-model-four-huggingface",
216-
"HubContentVersion": "2.0.2",
217-
"HubContentSearchKeywords": [
218-
"@jumpstart-model-id:model-four-huggingface",
219-
"@jumpstart-model-version:2.0.2",
220-
],
221-
},
222-
]
223-
}
245+
mock_list_models.return_value = MOCK_LIST_RESPONSE
224246
hub_name = "mock_hub_name"
225-
model_one = {"model_id": "mock-pytorch-model-does-not-exist"}
226-
model_two = {"model_id": "mock-pytorch-model-already-exists-in-hub", "version": "1.0.2"}
247+
model_one = {"model_id": "test-jumpstart-model-does-not-exist-pytorch"}
248+
model_two = {"model_id": "test-jumpstart-model-exists", "version": "2.0.0"}
227249
mock_sync_public_models.return_value = ""
228-
mock_model_id_generation.return_value = "test_model_id"
250+
mock_model_id_generation.return_value = MOCK_GENERATED_MODEL_ID
229251
hub = CuratedHub(hub_name=hub_name, sagemaker_session=sagemaker_session)
230252

231253
hub.sync([model_one, model_two])
232254

233255
mock_sync_public_models.assert_called_once_with(
234256
CuratedHubModelInfo(
235-
jumpstart_model_info=JumpStartModelInfo("mock-model-does-not-exist", "*"),
236-
hub_content_model_id="test_model_id",
257+
jumpstart_model_info=JumpStartModelInfo("test-jumpstart-model-does-not-exist-pytorch", "*"),
258+
hub_content_model_id=MOCK_GENERATED_MODEL_ID,
237259
hub_content_version="*"
238260
), 0
239261
)
@@ -242,56 +264,60 @@ def test_sync_filters_models_that_exist_in_hub(
242264
@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
243265
@patch(f"{MODULE_PATH}._sync_public_model_to_hub")
244266
@patch(f"{MODULE_PATH}._list_models")
245-
def test_sync_updates_old_models_in_hub(
267+
def test_sync_set_jumpstart_model_version(
246268
mock_list_models, mock_sync_public_models, mock_get_model_specs, mock_model_id_generation, sagemaker_session
247269
):
248270
mock_get_model_specs.side_effect = get_spec_from_base_spec
249-
mock_list_models.return_value = {
250-
"HubContentSummaries": [
251-
{
252-
"name": "test-model-two-pytorch",
253-
"version": "1.0.1",
254-
"search_keywords": [
255-
"@jumpstart-model-id:mock-model-two-pytorch",
256-
"@jumpstart-model-version:1.0.0",
257-
],
258-
},
259-
{
260-
"name": "mock-model-three-nonsense",
261-
"version": "1.0.2",
262-
"search_keywords": ["tag-one", "tag-two"],
263-
},
264-
{
265-
"name": "mock-model-four-huggingface",
266-
"version": "2.0.2",
267-
"search_keywords": [
268-
"@jumpstart-model-id:model-four-huggingface",
269-
"@jumpstart-model-version:2.0.2",
270-
],
271-
},
272-
]
273-
}
271+
mock_list_models.return_value = MOCK_LIST_RESPONSE
274272
hub_name = "mock_hub_name"
275-
model_one = {"model_id": "mock-model-one-huggingface"}
276-
model_two = {"model_id": "mock-model-two-pytorch", "version": "1.0.2"}
273+
model_one = {"model_id": "test-jumpstart-model-does-not-exist-pytorch", "version": "2.0.0"}
277274
mock_sync_public_models.return_value = ""
278-
mock_model_id_generation.return_value = "test_model_id"
275+
mock_model_id_generation.return_value = MOCK_GENERATED_MODEL_ID
279276
hub = CuratedHub(hub_name=hub_name, sagemaker_session=sagemaker_session)
280277

281-
hub.sync([model_one, model_two])
278+
hub.sync([model_one])
279+
280+
mock_sync_public_models.assert_called_once_with(
281+
CuratedHubModelInfo(
282+
jumpstart_model_info=JumpStartModelInfo("test-jumpstart-model-does-not-exist-pytorch", "2.0.0"),
283+
hub_content_model_id=MOCK_GENERATED_MODEL_ID,
284+
hub_content_version="*"
285+
), 0
286+
)
287+
288+
@patch("sagemaker.jumpstart.curated_hub.curated_hub.generate_unique_hub_content_model_name")
289+
@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
290+
@patch(f"{MODULE_PATH}._sync_public_model_to_hub")
291+
@patch(f"{MODULE_PATH}._list_models")
292+
def test_sync_import_new_version_only_if_jumpstart_model_version_is_greater(
293+
mock_list_models, mock_sync_public_models, mock_get_model_specs, mock_model_id_generation, sagemaker_session
294+
):
295+
mock_get_model_specs.side_effect = get_spec_from_base_spec
296+
mock_list_models.return_value = MOCK_LIST_RESPONSE
297+
hub_name = "mock_hub_name"
298+
model_one = {"model_id": "test-jumpstart-model-exists", "version": "1.0.0"}
299+
model_two = {"model_id": "test-jumpstart-model-exists", "version": "2.0.0"}
300+
model_three = {"model_id": "test-jumpstart-model-exists", "version": "2.0.1"}
301+
model_four = {"model_id": "test-jumpstart-model-hub-version-greater", "version": "1.0.0"}
302+
model_five = {"model_id": "test-jumpstart-model-hub-version-greater", "version": "2.0.1"}
303+
mock_sync_public_models.return_value = ""
304+
mock_model_id_generation.return_value = MOCK_GENERATED_MODEL_ID
305+
hub = CuratedHub(hub_name=hub_name, sagemaker_session=sagemaker_session)
306+
307+
hub.sync([model_one, model_two, model_three, model_four, model_five])
282308

283309
mock_sync_public_models.assert_has_calls(
284310
[
285311
mock.call(CuratedHubModelInfo(
286-
jumpstart_model_info=JumpStartModelInfo("mock-model-one-huggingface", "*"),
287-
hub_content_model_id="test_model_id",
312+
jumpstart_model_info=JumpStartModelInfo("test-jumpstart-model-exists", "2.0.1"),
313+
hub_content_model_id="mock-hub-content-single-version",
288314
hub_content_version="*"
289315
), 0),
290316
mock.call(CuratedHubModelInfo(
291-
jumpstart_model_info=JumpStartModelInfo("mock-model-two-pytorch", "1.0.2"),
292-
hub_content_model_id="test-model-two-pytorch",
317+
jumpstart_model_info=JumpStartModelInfo("test-jumpstart-model-hub-version-greater", "2.0.1"),
318+
hub_content_model_id="mock-hub-content-hub-version-greater",
293319
hub_content_version="*"
294-
), 1),
320+
), 1)
295321
]
296322
)
297323

0 commit comments

Comments
 (0)