52
52
get_jumpstart_model_and_version ,
53
53
find_jumpstart_tags_for_hub_content ,
54
54
get_latest_version_for_model ,
55
- summary_list_from_list_api_response
55
+ summary_list_from_list_api_response ,
56
+ generate_unique_hub_content_model_name
56
57
)
57
58
from sagemaker .jumpstart .curated_hub .types import (
58
59
HubContentDocument_v2 ,
59
60
JumpStartModelInfo ,
60
61
S3ObjectLocation ,
61
62
CuratedHubTag ,
62
63
HubContentSummary ,
64
+ CuratedHubModelInfo ,
63
65
)
64
66
67
+ LIST_HUB_CACHE = None
65
68
66
69
class CuratedHub :
67
70
"""Class for creating and managing a curated JumpStart hub"""
@@ -152,16 +155,17 @@ def describe(self) -> DescribeHubResponse:
152
155
return hub_description
153
156
154
157
155
- def list_models (self , clear_cache : bool = True , ** kwargs ) -> Dict [str , Any ]:
158
+ def list_models (self , clear_cache : bool = True , ** kwargs ) -> List [ Dict [str , Any ] ]:
156
159
"""Lists the models in this Curated Hub
157
160
158
161
**kwargs: Passed to invocation of ``Session:list_hub_contents``.
159
162
"""
160
163
if clear_cache :
161
- self ._list_models .cache_clear ()
164
+ LIST_HUB_CACHE = None
165
+ if LIST_HUB_CACHE :
166
+ return LIST_HUB_CACHE
162
167
return self ._list_models (** kwargs )
163
168
164
- @lru_cache (maxsize = 5 )
165
169
def _list_models (self , ** kwargs ) -> Dict [str , Any ]:
166
170
"""Lists the models in this Curated Hub
167
171
@@ -212,13 +216,12 @@ def _is_invalid_model_list_input(self, model_list: List[Dict[str, str]]) -> bool
212
216
return False
213
217
214
218
def _get_jumpstart_models_in_hub (self ) -> List [HubContentSummary ]:
215
- """Returns list of `HubContent` that have been created from a JumpStart model."""
216
- hub_models : List [HubContentSummary ] = summary_list_from_list_api_response (self .list_models ())
219
+ hub_models = summary_list_from_list_api_response (self .list_models ())
217
220
return [model for model in hub_models if get_jumpstart_model_and_version (model ) is not None ]
218
221
219
222
def _determine_models_to_sync (
220
- self , model_to_sync : List [JumpStartModelInfo ], models_in_hub : List [JumpStartModelInfo ]
221
- ) -> List [JumpStartModelInfo ]:
223
+ self , model_to_sync : List [JumpStartModelInfo ], models_in_hub : List [CuratedHubModelInfo ]
224
+ ) -> List [CuratedHubModelInfo ]:
222
225
"""Determines which models from `sync` params to sync into the CuratedHub.
223
226
224
227
Algorithm:
@@ -229,12 +232,22 @@ def _determine_models_to_sync(
229
232
in Hub, don't sync. If newer version in Hub, don't sync. If older version in Hub,
230
233
sync that model.
231
234
"""
232
- jumpstart_model_id_to_model_info_map = {model .model_id : model for model in models_in_hub }
233
- models_to_sync : List [JumpStartModelInfo ] = []
234
- for model in model_to_sync :
235
- matched_model = jumpstart_model_id_to_model_info_map .get (model .model_id )
236
- if not matched_model or Version (matched_model .version ) < Version (model .version ):
237
- models_to_sync .append (model )
235
+ hub_content_jumpstart_model_id_map = {model .jumpstart_model_info .model_id : model for model in models_in_hub }
236
+ models_to_sync : List [CuratedHubModelInfo ] = []
237
+ for public_hub_model in model_to_sync :
238
+ matched_model = hub_content_jumpstart_model_id_map .get (public_hub_model .model_id )
239
+ if not matched_model :
240
+ models_to_sync .append (CuratedHubModelInfo (
241
+ jumpstart_model_info = public_hub_model ,
242
+ hub_content_model_id = generate_unique_hub_content_model_name (public_hub_model .model_id ),
243
+ hub_content_version = "*"
244
+ ))
245
+ elif Version (matched_model .jumpstart_model_info .version ) < Version (public_hub_model .version ):
246
+ models_to_sync .append (CuratedHubModelInfo (
247
+ jumpstart_model_info = public_hub_model ,
248
+ hub_content_model_id = matched_model .hub_content_model_id ,
249
+ hub_content_version = "*"
250
+ ))
238
251
239
252
return models_to_sync
240
253
@@ -265,12 +278,17 @@ def sync(self, model_list: List[Dict[str, str]]):
265
278
)
266
279
model_version_list .append (JumpStartModelInfo (model_id , version ))
267
280
268
- js_models_in_hub = self ._get_jumpstart_models_in_hub ()
269
- js_models_in_hub = [get_jumpstart_model_and_version (model ) for model in js_models_in_hub ]
270
-
271
- models_to_sync = self ._determine_models_to_sync (model_version_list , js_models_in_hub )
281
+ jumpstart_models_in_hub = self ._get_jumpstart_models_in_hub ()
282
+ curated_models = [
283
+ CuratedHubModelInfo (
284
+ jumpstart_model_id = get_jumpstart_model_and_version (model ),
285
+ hub_content_model_id = model .hub_content_name ,
286
+ hub_content_version = model .hub_content_version
287
+ ) for model in jumpstart_models_in_hub
288
+ ]
289
+ models_to_sync = self ._determine_models_to_sync (model_version_list , curated_models )
272
290
JUMPSTART_LOGGER .warning (
273
- "Syncing the following models into Hub %s: %s" , self .hub_name , models_to_sync
291
+ "Syncing the following models into Hub %s: %s" , self .hub_name , [ model . jumpstart_model_info for model in models_to_sync ]
274
292
)
275
293
276
294
# Delete old models?
@@ -304,11 +322,11 @@ def sync(self, model_list: List[Dict[str, str]]):
304
322
f"Failures when importing models to curated hub in parallel: { failed_imports } "
305
323
)
306
324
307
- def _sync_public_model_to_hub (self , model : JumpStartModelInfo , thread_num : int ):
325
+ def _sync_public_model_to_hub (self , model : CuratedHubModelInfo , thread_num : int ):
308
326
"""Syncs a public JumpStart model version to the Hub. Runs in parallel."""
309
327
model_specs = utils .verify_model_region_and_return_specs (
310
- model_id = model .model_id ,
311
- version = model .version ,
328
+ model_id = model .jumpstart_model_info . model_id ,
329
+ version = model .jumpstart_model_info . version ,
312
330
region = self .region ,
313
331
scope = JumpStartScriptScope .INFERENCE ,
314
332
sagemaker_session = self ._sagemaker_session ,
@@ -317,7 +335,7 @@ def _sync_public_model_to_hub(self, model: JumpStartModelInfo, thread_num: int):
317
335
318
336
dest_location = S3ObjectLocation (
319
337
bucket = self .hub_storage_location .bucket ,
320
- key = f"{ self .hub_storage_location .key } /curated_models/{ model .model_id } /{ model .version } " ,
338
+ key = f"{ self .hub_storage_location .key } /curated_models/{ model .jumpstart_model_info . model_id } /{ model . jumpstart_model_info .version } " ,
321
339
)
322
340
src_files = file_generator .generate_file_infos_from_model_specs (
323
341
model_specs , studio_specs , self .region , self ._s3_client
@@ -339,16 +357,16 @@ def _sync_public_model_to_hub(self, model: JumpStartModelInfo, thread_num: int):
339
357
label = dest_location .key ,
340
358
).execute ()
341
359
else :
342
- JUMPSTART_LOGGER .warning ("Nothing to copy for %s v%s" , model .model_id , model .version )
360
+ JUMPSTART_LOGGER .warning ("Nothing to copy for %s v%s" , model .jumpstart_model_info . model_id , model . jumpstart_model_info .version )
343
361
344
362
# TODO: Tag model if specs say it is deprecated or training/inference
345
363
# vulnerable. Update tag of HubContent ARN without version.
346
364
# Versioned ARNs are not onboarded to Tagris.
347
365
tags = []
348
366
349
367
search_keywords = [
350
- f"{ JUMPSTART_HUB_MODEL_ID_TAG_PREFIX } :{ model .model_id } " ,
351
- f"{ JUMPSTART_HUB_MODEL_VERSION_TAG_PREFIX } :{ model .version } " ,
368
+ f"{ JUMPSTART_HUB_MODEL_ID_TAG_PREFIX } :{ model .jumpstart_model_info . model_id } " ,
369
+ f"{ JUMPSTART_HUB_MODEL_VERSION_TAG_PREFIX } :{ model .jumpstart_model_info . version } " ,
352
370
f"{ FRAMEWORK_TAG_PREFIX } :{ model_specs .get_framework ()} " ,
353
371
f"{ TASK_TAG_PREFIX } :TODO: pull from specs" ,
354
372
]
@@ -357,8 +375,8 @@ def _sync_public_model_to_hub(self, model: JumpStartModelInfo, thread_num: int):
357
375
358
376
self ._sagemaker_session .import_hub_content (
359
377
document_schema_version = HubContentDocument_v2 .SCHEMA_VERSION ,
360
- hub_content_name = model .model_id ,
361
- hub_content_version = model .version ,
378
+ hub_content_name = model .hub_content_model_id ,
379
+ hub_content_version = model .hub_content_version ,
362
380
hub_name = self .hub_name ,
363
381
hub_content_document = hub_content_document ,
364
382
hub_content_type = HubContentType .MODEL ,
@@ -398,9 +416,9 @@ def scan_and_tag_models(self) -> None:
398
416
JUMPSTART_LOGGER .info (
399
417
"Tagging models in hub: %s" , self .hub_name
400
418
)
401
- models_in_hub : List [ HubContentSummary ] = self ._get_jumpstart_models_in_hub ()
419
+ js_models_in_hub = [ model for model in self .list_models () if get_jumpstart_model_and_version ( model ) is not None ]
402
420
tags_added : Dict [str , List [CuratedHubTag ]] = {}
403
- for model in models_in_hub :
421
+ for model in js_models_in_hub :
404
422
tags_to_add : List [CuratedHubTag ] = find_jumpstart_tags_for_hub_content (
405
423
hub_name = self .hub_name ,
406
424
hub_content_name = model .hub_content_name ,
0 commit comments