32
32
33
33
FAKE_TIME = datetime .datetime (1997 , 8 , 14 , 00 , 00 , 00 )
34
34
35
+ MOCK_GENERATED_MODEL_ID = "mock_generated_model_id"
36
+ MOCK_LIST_RESPONSE = {
37
+ "HubContentSummaries" : [
38
+ {
39
+ "HubContentName" : "mock-hub-content-single-version" ,
40
+ "HubContentVersion" : "1.0.0" ,
41
+ "HubContentSearchKeywords" : [
42
+ "@jumpstart-model-id:test-jumpstart-model-exists" ,
43
+ "@jumpstart-model-version:2.0.0" ,
44
+ ],
45
+ },
46
+ {
47
+ "HubContentName" : "mock-hub-content-hub-version-greater" ,
48
+ "HubContentVersion" : "3.0.0" ,
49
+ "HubContentSearchKeywords" : [
50
+ "@jumpstart-model-id:test-jumpstart-model-hub-version-greater" ,
51
+ "@jumpstart-model-version:2.0.0" ,
52
+ ],
53
+ },
54
+ {"HubContentName" : "test-model-no-keywords" , "HubContentVersion" : "1.0.0" , "HubContentSearchKeywords" : []},
55
+ {"HubContentName" : "test-model-no-keywords" , "HubContentVersion" : "1.0.1" , "HubContentSearchKeywords" : []},
56
+ {"HubContentName" : "test-model-no-jumpstart-keywords" , "HubContentVersion" : "1.0.2" , "HubContentSearchKeywords" : ["tag" , "tag2" ]},
57
+ {"HubContentName" : "test-model-missing-jumpstart-keywords" , "HubContentVersion" : "1.0.2" , "HubContentSearchKeywords" : ["@jumpstart-model-id:model-one-pytorch" , "tag2" ]},
58
+ {
59
+ "HubContentName" : "mock-hub-content-multiple-versions" ,
60
+ "HubContentVersion" : "1.0.0" ,
61
+ "HubContentSearchKeywords" : [
62
+ "@jumpstart-model-id:test-jumpstart-model-multiple-versions" ,
63
+ "@jumpstart-model-version:2.0.0" ,
64
+ ],
65
+ },
66
+ {
67
+ "HubContentName" : "mock-hub-content-multiple-versions" ,
68
+ "HubContentVersion" : "1.0.1" ,
69
+ "HubContentSearchKeywords" : [
70
+ "@jumpstart-model-id:test-jumpstart-model-multiple-versions" ,
71
+ "@jumpstart-model-version:2.0.1" ,
72
+ ],
73
+ },
74
+ ]
75
+ }
76
+
35
77
36
78
@pytest .fixture ()
37
79
def sagemaker_session ():
@@ -200,40 +242,20 @@ def test_sync_filters_models_that_exist_in_hub(
200
242
mock_list_models , mock_sync_public_models , mock_get_model_specs , mock_model_id_generation , sagemaker_session
201
243
):
202
244
mock_get_model_specs .side_effect = get_spec_from_base_spec
203
- mock_list_models .return_value = {
204
- "HubContentSummaries" : [
205
- {
206
- "HubContentName" : "mock-model-two-pytorch" ,
207
- "HubContentVersion" : "1.0.2" ,
208
- "HubContentSearchKeywords" : [
209
- "@jumpstart-model-id:mock-pytorch-model-already-exists-in-hub" ,
210
- "@jumpstart-model-version:1.0.2" ,
211
- ],
212
- },
213
- {"HubContentName" : "mock-model-three-nonsense" , "HubContentVersion" : "1.0.2" , "HubContentSearchKeywords" : []},
214
- {
215
- "HubContentName" : "mock-model-four-huggingface" ,
216
- "HubContentVersion" : "2.0.2" ,
217
- "HubContentSearchKeywords" : [
218
- "@jumpstart-model-id:model-four-huggingface" ,
219
- "@jumpstart-model-version:2.0.2" ,
220
- ],
221
- },
222
- ]
223
- }
245
+ mock_list_models .return_value = MOCK_LIST_RESPONSE
224
246
hub_name = "mock_hub_name"
225
- model_one = {"model_id" : "mock-pytorch -model-does-not-exist" }
226
- model_two = {"model_id" : "mock-pytorch -model-already- exists-in-hub " , "version" : "1 .0.2 " }
247
+ model_one = {"model_id" : "test-jumpstart -model-does-not-exist-pytorch " }
248
+ model_two = {"model_id" : "test-jumpstart -model-exists" , "version" : "2 .0.0 " }
227
249
mock_sync_public_models .return_value = ""
228
- mock_model_id_generation .return_value = "test_model_id"
250
+ mock_model_id_generation .return_value = MOCK_GENERATED_MODEL_ID
229
251
hub = CuratedHub (hub_name = hub_name , sagemaker_session = sagemaker_session )
230
252
231
253
hub .sync ([model_one , model_two ])
232
254
233
255
mock_sync_public_models .assert_called_once_with (
234
256
CuratedHubModelInfo (
235
- jumpstart_model_info = JumpStartModelInfo ("mock- model-does-not-exist" , "*" ),
236
- hub_content_model_id = "test_model_id" ,
257
+ jumpstart_model_info = JumpStartModelInfo ("test-jumpstart- model-does-not-exist-pytorch " , "*" ),
258
+ hub_content_model_id = MOCK_GENERATED_MODEL_ID ,
237
259
hub_content_version = "*"
238
260
), 0
239
261
)
@@ -242,56 +264,60 @@ def test_sync_filters_models_that_exist_in_hub(
242
264
@patch ("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs" )
243
265
@patch (f"{ MODULE_PATH } ._sync_public_model_to_hub" )
244
266
@patch (f"{ MODULE_PATH } ._list_models" )
245
- def test_sync_updates_old_models_in_hub (
267
+ def test_sync_set_jumpstart_model_version (
246
268
mock_list_models , mock_sync_public_models , mock_get_model_specs , mock_model_id_generation , sagemaker_session
247
269
):
248
270
mock_get_model_specs .side_effect = get_spec_from_base_spec
249
- mock_list_models .return_value = {
250
- "HubContentSummaries" : [
251
- {
252
- "name" : "test-model-two-pytorch" ,
253
- "version" : "1.0.1" ,
254
- "search_keywords" : [
255
- "@jumpstart-model-id:mock-model-two-pytorch" ,
256
- "@jumpstart-model-version:1.0.0" ,
257
- ],
258
- },
259
- {
260
- "name" : "mock-model-three-nonsense" ,
261
- "version" : "1.0.2" ,
262
- "search_keywords" : ["tag-one" , "tag-two" ],
263
- },
264
- {
265
- "name" : "mock-model-four-huggingface" ,
266
- "version" : "2.0.2" ,
267
- "search_keywords" : [
268
- "@jumpstart-model-id:model-four-huggingface" ,
269
- "@jumpstart-model-version:2.0.2" ,
270
- ],
271
- },
272
- ]
273
- }
271
+ mock_list_models .return_value = MOCK_LIST_RESPONSE
274
272
hub_name = "mock_hub_name"
275
- model_one = {"model_id" : "mock-model-one-huggingface" }
276
- model_two = {"model_id" : "mock-model-two-pytorch" , "version" : "1.0.2" }
273
+ model_one = {"model_id" : "test-jumpstart-model-does-not-exist-pytorch" , "version" : "2.0.0" }
277
274
mock_sync_public_models .return_value = ""
278
- mock_model_id_generation .return_value = "test_model_id"
275
+ mock_model_id_generation .return_value = MOCK_GENERATED_MODEL_ID
279
276
hub = CuratedHub (hub_name = hub_name , sagemaker_session = sagemaker_session )
280
277
281
- hub .sync ([model_one , model_two ])
278
+ hub .sync ([model_one ])
279
+
280
+ mock_sync_public_models .assert_called_once_with (
281
+ CuratedHubModelInfo (
282
+ jumpstart_model_info = JumpStartModelInfo ("test-jumpstart-model-does-not-exist-pytorch" , "2.0.0" ),
283
+ hub_content_model_id = MOCK_GENERATED_MODEL_ID ,
284
+ hub_content_version = "*"
285
+ ), 0
286
+ )
287
+
288
+ @patch ("sagemaker.jumpstart.curated_hub.curated_hub.generate_unique_hub_content_model_name" )
289
+ @patch ("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs" )
290
+ @patch (f"{ MODULE_PATH } ._sync_public_model_to_hub" )
291
+ @patch (f"{ MODULE_PATH } ._list_models" )
292
+ def test_sync_import_new_version_only_if_jumpstart_model_version_is_greater (
293
+ mock_list_models , mock_sync_public_models , mock_get_model_specs , mock_model_id_generation , sagemaker_session
294
+ ):
295
+ mock_get_model_specs .side_effect = get_spec_from_base_spec
296
+ mock_list_models .return_value = MOCK_LIST_RESPONSE
297
+ hub_name = "mock_hub_name"
298
+ model_one = {"model_id" : "test-jumpstart-model-exists" , "version" : "1.0.0" }
299
+ model_two = {"model_id" : "test-jumpstart-model-exists" , "version" : "2.0.0" }
300
+ model_three = {"model_id" : "test-jumpstart-model-exists" , "version" : "2.0.1" }
301
+ model_four = {"model_id" : "test-jumpstart-model-hub-version-greater" , "version" : "1.0.0" }
302
+ model_five = {"model_id" : "test-jumpstart-model-hub-version-greater" , "version" : "2.0.1" }
303
+ mock_sync_public_models .return_value = ""
304
+ mock_model_id_generation .return_value = MOCK_GENERATED_MODEL_ID
305
+ hub = CuratedHub (hub_name = hub_name , sagemaker_session = sagemaker_session )
306
+
307
+ hub .sync ([model_one , model_two , model_three , model_four , model_five ])
282
308
283
309
mock_sync_public_models .assert_has_calls (
284
310
[
285
311
mock .call (CuratedHubModelInfo (
286
- jumpstart_model_info = JumpStartModelInfo ("mock- model-one-huggingface " , "* " ),
287
- hub_content_model_id = "test_model_id " ,
312
+ jumpstart_model_info = JumpStartModelInfo ("test-jumpstart- model-exists " , "2.0.1 " ),
313
+ hub_content_model_id = "mock-hub-content-single-version " ,
288
314
hub_content_version = "*"
289
315
), 0 ),
290
316
mock .call (CuratedHubModelInfo (
291
- jumpstart_model_info = JumpStartModelInfo ("mock- model-two-pytorch " , "1 .0.2 " ),
292
- hub_content_model_id = "test-model-two-pytorch " ,
317
+ jumpstart_model_info = JumpStartModelInfo ("test-jumpstart- model-hub-version-greater " , "2 .0.1 " ),
318
+ hub_content_model_id = "mock-hub-content-hub-version-greater " ,
293
319
hub_content_version = "*"
294
- ), 1 ),
320
+ ), 1 )
295
321
]
296
322
)
297
323
0 commit comments