18
18
import pytest
19
19
from mock import Mock
20
20
from sagemaker .jumpstart .curated_hub .curated_hub import CuratedHub
21
- from sagemaker .jumpstart .curated_hub .types import JumpStartModelInfo , S3ObjectLocation
21
+ from sagemaker .jumpstart .curated_hub .types import JumpStartModelInfo , S3ObjectLocation , CuratedHubModelInfo
22
22
from sagemaker .jumpstart .types import JumpStartModelSpecs
23
23
from tests .unit .sagemaker .jumpstart .constants import BASE_SPEC
24
24
from tests .unit .sagemaker .jumpstart .utils import get_spec_from_base_spec
@@ -48,7 +48,6 @@ def sagemaker_session():
48
48
sagemaker_session_mock .account_id .return_value = ACCOUNT_ID
49
49
return sagemaker_session_mock
50
50
51
-
52
51
def test_instantiates (sagemaker_session ):
53
52
hub = CuratedHub (hub_name = HUB_NAME , sagemaker_session = sagemaker_session )
54
53
assert hub .hub_name == HUB_NAME
@@ -160,36 +159,45 @@ def test_create_with_bucket_name(
160
159
sagemaker_session .create_hub .assert_called_with (** request )
161
160
assert response == {"HubArn" : f"arn:aws:sagemaker:us-east-1:123456789123:hub/{ hub_name } " }
162
161
163
-
162
+ @ patch ( "sagemaker.jumpstart.curated_hub.curated_hub.generate_unique_hub_content_model_name" )
164
163
@patch ("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs" )
165
164
@patch (f"{ MODULE_PATH } ._sync_public_model_to_hub" )
166
165
@patch (f"{ MODULE_PATH } .list_models" )
167
166
def test_sync_kicks_off_parallel_syncs (
168
- mock_list_models , mock_sync_public_models , mock_get_model_specs , sagemaker_session
167
+ mock_list_models , mock_sync_public_models , mock_get_model_specs , mock_model_id_generation , sagemaker_session
169
168
):
170
169
mock_get_model_specs .side_effect = get_spec_from_base_spec
171
170
mock_list_models .return_value = {"HubContentSummaries" : []}
172
171
hub_name = "mock_hub_name"
173
172
model_one = {"model_id" : "mock-model-one-huggingface" }
174
173
model_two = {"model_id" : "mock-model-two-pytorch" , "version" : "1.0.2" }
175
174
mock_sync_public_models .return_value = ""
175
+ mock_model_id_generation .return_value = "test_model_id"
176
176
hub = CuratedHub (hub_name = hub_name , sagemaker_session = sagemaker_session )
177
177
178
178
hub .sync ([model_one , model_two ])
179
179
180
180
mock_sync_public_models .assert_has_calls (
181
181
[
182
- mock .call (JumpStartModelInfo ("mock-model-one-huggingface" , "*" ), 0 ),
183
- mock .call (JumpStartModelInfo ("mock-model-two-pytorch" , "1.0.2" ), 1 ),
182
+ mock .call (CuratedHubModelInfo (
183
+ jumpstart_model_info = JumpStartModelInfo ("mock-model-one-huggingface" , "*" ),
184
+ hub_content_model_id = "test_model_id" ,
185
+ hub_content_version = "*"
186
+ ), 0 ),
187
+ mock .call (CuratedHubModelInfo (
188
+ jumpstart_model_info = JumpStartModelInfo ("mock-model-two-pytorch" , "1.0.2" ),
189
+ hub_content_model_id = "test_model_id" ,
190
+ hub_content_version = "*"
191
+ ), 1 ),
184
192
]
185
193
)
186
194
187
-
195
+ @ patch ( "sagemaker.jumpstart.curated_hub.curated_hub.generate_unique_hub_content_model_name" )
188
196
@patch ("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs" )
189
197
@patch (f"{ MODULE_PATH } ._sync_public_model_to_hub" )
190
198
@patch (f"{ MODULE_PATH } ._list_models" )
191
199
def test_sync_filters_models_that_exist_in_hub (
192
- mock_list_models , mock_sync_public_models , mock_get_model_specs , sagemaker_session
200
+ mock_list_models , mock_sync_public_models , mock_get_model_specs , mock_model_id_generation , sagemaker_session
193
201
):
194
202
mock_get_model_specs .side_effect = get_spec_from_base_spec
195
203
mock_list_models .return_value = {
@@ -217,30 +225,35 @@ def test_sync_filters_models_that_exist_in_hub(
217
225
model_one = {"model_id" : "mock-pytorch-model-does-not-exist" }
218
226
model_two = {"model_id" : "mock-pytorch-model-already-exists-in-hub" , "version" : "1.0.2" }
219
227
mock_sync_public_models .return_value = ""
228
+ mock_model_id_generation .return_value = "test_model_id"
220
229
hub = CuratedHub (hub_name = hub_name , sagemaker_session = sagemaker_session )
221
230
222
231
hub .sync ([model_one , model_two ])
223
232
224
233
mock_sync_public_models .assert_called_once_with (
225
- JumpStartModelInfo ("mock-model-does-not-exist" , "*" ), 0
234
+ CuratedHubModelInfo (
235
+ jumpstart_model_info = JumpStartModelInfo ("mock-model-does-not-exist" , "*" ),
236
+ hub_content_model_id = "test_model_id" ,
237
+ hub_content_version = "*"
238
+ ), 0
226
239
)
227
240
228
-
241
+ @ patch ( "sagemaker.jumpstart.curated_hub.curated_hub.generate_unique_hub_content_model_name" )
229
242
@patch ("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs" )
230
243
@patch (f"{ MODULE_PATH } ._sync_public_model_to_hub" )
231
244
@patch (f"{ MODULE_PATH } ._list_models" )
232
245
def test_sync_updates_old_models_in_hub (
233
- mock_list_models , mock_sync_public_models , mock_get_model_specs , sagemaker_session
246
+ mock_list_models , mock_sync_public_models , mock_get_model_specs , mock_model_id_generation , sagemaker_session
234
247
):
235
248
mock_get_model_specs .side_effect = get_spec_from_base_spec
236
249
mock_list_models .return_value = {
237
250
"HubContentSummaries" : [
238
251
{
239
- "name" : "mock -model-two-pytorch" ,
252
+ "name" : "test -model-two-pytorch" ,
240
253
"version" : "1.0.1" ,
241
254
"search_keywords" : [
242
- "@jumpstart-model-id:model-two-pytorch" ,
243
- "@jumpstart-model-version:1.0.2 " ,
255
+ "@jumpstart-model-id:mock- model-two-pytorch" ,
256
+ "@jumpstart-model-version:1.0.0 " ,
244
257
],
245
258
},
246
259
{
@@ -262,14 +275,23 @@ def test_sync_updates_old_models_in_hub(
262
275
model_one = {"model_id" : "mock-model-one-huggingface" }
263
276
model_two = {"model_id" : "mock-model-two-pytorch" , "version" : "1.0.2" }
264
277
mock_sync_public_models .return_value = ""
278
+ mock_model_id_generation .return_value = "test_model_id"
265
279
hub = CuratedHub (hub_name = hub_name , sagemaker_session = sagemaker_session )
266
280
267
281
hub .sync ([model_one , model_two ])
268
282
269
283
mock_sync_public_models .assert_has_calls (
270
284
[
271
- mock .call (JumpStartModelInfo ("mock-model-one-huggingface" , "*" ), 0 ),
272
- mock .call (JumpStartModelInfo ("mock-model-two-pytorch" , "1.0.2" ), 1 ),
285
+ mock .call (CuratedHubModelInfo (
286
+ jumpstart_model_info = JumpStartModelInfo ("mock-model-one-huggingface" , "*" ),
287
+ hub_content_model_id = "test_model_id" ,
288
+ hub_content_version = "*"
289
+ ), 0 ),
290
+ mock .call (CuratedHubModelInfo (
291
+ jumpstart_model_info = JumpStartModelInfo ("mock-model-two-pytorch" , "1.0.2" ),
292
+ hub_content_model_id = "test-model-two-pytorch" ,
293
+ hub_content_version = "*"
294
+ ), 1 ),
273
295
]
274
296
)
275
297
0 commit comments