Skip to content

Commit 453bed4

Browse files
committed
fix: Addressing unit tests
1 parent c957e71 commit 453bed4

File tree

1 file changed

+38
-16
lines changed

1 file changed

+38
-16
lines changed

tests/unit/sagemaker/jumpstart/curated_hub/test_curated_hub.py

Lines changed: 38 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
import pytest
1919
from mock import Mock
2020
from sagemaker.jumpstart.curated_hub.curated_hub import CuratedHub
21-
from sagemaker.jumpstart.curated_hub.types import JumpStartModelInfo, S3ObjectLocation
21+
from sagemaker.jumpstart.curated_hub.types import JumpStartModelInfo, S3ObjectLocation, CuratedHubModelInfo
2222
from sagemaker.jumpstart.types import JumpStartModelSpecs
2323
from tests.unit.sagemaker.jumpstart.constants import BASE_SPEC
2424
from tests.unit.sagemaker.jumpstart.utils import get_spec_from_base_spec
@@ -48,7 +48,6 @@ def sagemaker_session():
4848
sagemaker_session_mock.account_id.return_value = ACCOUNT_ID
4949
return sagemaker_session_mock
5050

51-
5251
def test_instantiates(sagemaker_session):
5352
hub = CuratedHub(hub_name=HUB_NAME, sagemaker_session=sagemaker_session)
5453
assert hub.hub_name == HUB_NAME
@@ -160,36 +159,45 @@ def test_create_with_bucket_name(
160159
sagemaker_session.create_hub.assert_called_with(**request)
161160
assert response == {"HubArn": f"arn:aws:sagemaker:us-east-1:123456789123:hub/{hub_name}"}
162161

163-
162+
@patch("sagemaker.jumpstart.curated_hub.curated_hub.generate_unique_hub_content_model_name")
164163
@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
165164
@patch(f"{MODULE_PATH}._sync_public_model_to_hub")
166165
@patch(f"{MODULE_PATH}.list_models")
167166
def test_sync_kicks_off_parallel_syncs(
168-
mock_list_models, mock_sync_public_models, mock_get_model_specs, sagemaker_session
167+
mock_list_models, mock_sync_public_models, mock_get_model_specs, mock_model_id_generation, sagemaker_session
169168
):
170169
mock_get_model_specs.side_effect = get_spec_from_base_spec
171170
mock_list_models.return_value = {"HubContentSummaries": []}
172171
hub_name = "mock_hub_name"
173172
model_one = {"model_id": "mock-model-one-huggingface"}
174173
model_two = {"model_id": "mock-model-two-pytorch", "version": "1.0.2"}
175174
mock_sync_public_models.return_value = ""
175+
mock_model_id_generation.return_value = "test_model_id"
176176
hub = CuratedHub(hub_name=hub_name, sagemaker_session=sagemaker_session)
177177

178178
hub.sync([model_one, model_two])
179179

180180
mock_sync_public_models.assert_has_calls(
181181
[
182-
mock.call(JumpStartModelInfo("mock-model-one-huggingface", "*"), 0),
183-
mock.call(JumpStartModelInfo("mock-model-two-pytorch", "1.0.2"), 1),
182+
mock.call(CuratedHubModelInfo(
183+
jumpstart_model_info=JumpStartModelInfo("mock-model-one-huggingface", "*"),
184+
hub_content_model_id="test_model_id",
185+
hub_content_version="*"
186+
), 0),
187+
mock.call(CuratedHubModelInfo(
188+
jumpstart_model_info=JumpStartModelInfo("mock-model-two-pytorch", "1.0.2"),
189+
hub_content_model_id="test_model_id",
190+
hub_content_version="*"
191+
), 1),
184192
]
185193
)
186194

187-
195+
@patch("sagemaker.jumpstart.curated_hub.curated_hub.generate_unique_hub_content_model_name")
188196
@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
189197
@patch(f"{MODULE_PATH}._sync_public_model_to_hub")
190198
@patch(f"{MODULE_PATH}._list_models")
191199
def test_sync_filters_models_that_exist_in_hub(
192-
mock_list_models, mock_sync_public_models, mock_get_model_specs, sagemaker_session
200+
mock_list_models, mock_sync_public_models, mock_get_model_specs, mock_model_id_generation, sagemaker_session
193201
):
194202
mock_get_model_specs.side_effect = get_spec_from_base_spec
195203
mock_list_models.return_value = {
@@ -217,30 +225,35 @@ def test_sync_filters_models_that_exist_in_hub(
217225
model_one = {"model_id": "mock-pytorch-model-does-not-exist"}
218226
model_two = {"model_id": "mock-pytorch-model-already-exists-in-hub", "version": "1.0.2"}
219227
mock_sync_public_models.return_value = ""
228+
mock_model_id_generation.return_value = "test_model_id"
220229
hub = CuratedHub(hub_name=hub_name, sagemaker_session=sagemaker_session)
221230

222231
hub.sync([model_one, model_two])
223232

224233
mock_sync_public_models.assert_called_once_with(
225-
JumpStartModelInfo("mock-model-does-not-exist", "*"), 0
234+
CuratedHubModelInfo(
235+
jumpstart_model_info=JumpStartModelInfo("mock-model-does-not-exist", "*"),
236+
hub_content_model_id="test_model_id",
237+
hub_content_version="*"
238+
), 0
226239
)
227240

228-
241+
@patch("sagemaker.jumpstart.curated_hub.curated_hub.generate_unique_hub_content_model_name")
229242
@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
230243
@patch(f"{MODULE_PATH}._sync_public_model_to_hub")
231244
@patch(f"{MODULE_PATH}._list_models")
232245
def test_sync_updates_old_models_in_hub(
233-
mock_list_models, mock_sync_public_models, mock_get_model_specs, sagemaker_session
246+
mock_list_models, mock_sync_public_models, mock_get_model_specs, mock_model_id_generation, sagemaker_session
234247
):
235248
mock_get_model_specs.side_effect = get_spec_from_base_spec
236249
mock_list_models.return_value = {
237250
"HubContentSummaries": [
238251
{
239-
"name": "mock-model-two-pytorch",
252+
"name": "test-model-two-pytorch",
240253
"version": "1.0.1",
241254
"search_keywords": [
242-
"@jumpstart-model-id:model-two-pytorch",
243-
"@jumpstart-model-version:1.0.2",
255+
"@jumpstart-model-id:mock-model-two-pytorch",
256+
"@jumpstart-model-version:1.0.0",
244257
],
245258
},
246259
{
@@ -262,14 +275,23 @@ def test_sync_updates_old_models_in_hub(
262275
model_one = {"model_id": "mock-model-one-huggingface"}
263276
model_two = {"model_id": "mock-model-two-pytorch", "version": "1.0.2"}
264277
mock_sync_public_models.return_value = ""
278+
mock_model_id_generation.return_value = "test_model_id"
265279
hub = CuratedHub(hub_name=hub_name, sagemaker_session=sagemaker_session)
266280

267281
hub.sync([model_one, model_two])
268282

269283
mock_sync_public_models.assert_has_calls(
270284
[
271-
mock.call(JumpStartModelInfo("mock-model-one-huggingface", "*"), 0),
272-
mock.call(JumpStartModelInfo("mock-model-two-pytorch", "1.0.2"), 1),
285+
mock.call(CuratedHubModelInfo(
286+
jumpstart_model_info=JumpStartModelInfo("mock-model-one-huggingface", "*"),
287+
hub_content_model_id="test_model_id",
288+
hub_content_version="*"
289+
), 0),
290+
mock.call(CuratedHubModelInfo(
291+
jumpstart_model_info=JumpStartModelInfo("mock-model-two-pytorch", "1.0.2"),
292+
hub_content_model_id="test-model-two-pytorch",
293+
hub_content_version="*"
294+
), 1),
273295
]
274296
)
275297

0 commit comments

Comments
 (0)