51
51
tag_hub_content ,
52
52
get_jumpstart_model_and_version ,
53
53
find_jumpstart_tags_for_hub_content ,
54
- get_latest_version_for_model ,
55
54
summary_list_from_list_api_response ,
56
- generate_unique_hub_content_model_name
57
55
)
58
56
from sagemaker .jumpstart .curated_hub .types import (
59
57
HubContentDocument_v2 ,
60
58
JumpStartModelInfo ,
61
59
S3ObjectLocation ,
62
60
CuratedHubTag ,
63
61
HubContentSummary ,
64
- CuratedHubModelInfo ,
65
62
)
66
63
67
64
LIST_HUB_CACHE = None
@@ -215,13 +212,23 @@ def _is_invalid_model_list_input(self, model_list: List[Dict[str, str]]) -> bool
215
212
return True
216
213
return False
217
214
215
+ def _populate_latest_model_version (self , model : Dict [str , str ]) -> Dict [str , str ]:
216
+ """Populates the lastest version of a model from specs no matter what is passed.
217
+
218
+ Returns model ({ model_id: str, version: str })
219
+ """
220
+ model_specs = utils .verify_model_region_and_return_specs (
221
+ model ["model_id" ], "*" , JumpStartScriptScope .INFERENCE , self .region
222
+ )
223
+ return {"model_id" : model ["model_id" ], "version" : model_specs .version }
224
+
218
225
def _get_jumpstart_models_in_hub (self ) -> List [HubContentSummary ]:
219
226
hub_models = summary_list_from_list_api_response (self .list_models ())
220
227
return [model for model in hub_models if get_jumpstart_model_and_version (model ) is not None ]
221
228
222
229
def _determine_models_to_sync (
223
- self , model_to_sync : List [JumpStartModelInfo ], models_in_hub : List [ CuratedHubModelInfo ]
224
- ) -> List [CuratedHubModelInfo ]:
230
+ self , model_list : List [JumpStartModelInfo ], models_in_hub : Dict [ str , HubContentSummary ]
231
+ ) -> List [JumpStartModelInfo ]:
225
232
"""Determines which models from `sync` params to sync into the CuratedHub.
226
233
227
234
Algorithm:
@@ -232,22 +239,11 @@ def _determine_models_to_sync(
232
239
in Hub, don't sync. If newer version in Hub, don't sync. If older version in Hub,
233
240
sync that model.
234
241
"""
235
- hub_content_jumpstart_model_id_map = {model .jumpstart_model_info .model_id : model for model in models_in_hub }
236
- models_to_sync : List [CuratedHubModelInfo ] = []
237
- for public_hub_model in model_to_sync :
238
- matched_model = hub_content_jumpstart_model_id_map .get (public_hub_model .model_id )
239
- if not matched_model :
240
- models_to_sync .append (CuratedHubModelInfo (
241
- jumpstart_model_info = public_hub_model ,
242
- hub_content_model_id = generate_unique_hub_content_model_name (public_hub_model .model_id ),
243
- hub_content_version = "*"
244
- ))
245
- elif Version (matched_model .jumpstart_model_info .version ) < Version (public_hub_model .version ):
246
- models_to_sync .append (CuratedHubModelInfo (
247
- jumpstart_model_info = public_hub_model ,
248
- hub_content_model_id = matched_model .hub_content_model_id ,
249
- hub_content_version = "*"
250
- ))
242
+ models_to_sync = []
243
+ for model in model_list :
244
+ matched_model = models_in_hub .get (model .model_id )
245
+ if not matched_model or Version (matched_model .hub_content_version ) < Version (model .version ):
246
+ models_to_sync .append (model )
251
247
252
248
return models_to_sync
253
249
@@ -265,30 +261,24 @@ def sync(self, model_list: List[Dict[str, str]]):
265
261
)
266
262
267
263
# Retrieve latest version of unspecified JumpStart model versions
268
- model_version_list : List [ JumpStartModelInfo ] = []
264
+ model_version_list = []
269
265
for model in model_list :
270
- model_id = model .get ("model_id" )
271
266
version = model .get ("version" , "*" )
272
267
if version == "*" :
273
- version = get_latest_version_for_model ( model_id = model_id , region = self .region )
268
+ model = self ._populate_latest_model_version ( model )
274
269
JUMPSTART_LOGGER .warning (
275
270
"No version specified for model %s. Using version %s" ,
276
- model_id ,
277
- version ,
271
+ model [ " model_id" ] ,
272
+ model [ " version" ] ,
278
273
)
279
- model_version_list .append (JumpStartModelInfo (model_id , version ))
280
-
281
- jumpstart_models_in_hub = self ._get_jumpstart_models_in_hub ()
282
- curated_models = [
283
- CuratedHubModelInfo (
284
- jumpstart_model_info = get_jumpstart_model_and_version (model ),
285
- hub_content_model_id = model .hub_content_name ,
286
- hub_content_version = model .hub_content_version
287
- ) for model in jumpstart_models_in_hub
288
- ]
289
- models_to_sync = self ._determine_models_to_sync (model_version_list , curated_models )
274
+ model_version_list .append (JumpStartModelInfo (model ["model_id" ], model ["version" ]))
275
+
276
+ js_models_in_hub = self ._get_jumpstart_models_in_hub ()
277
+ mapped_models_in_hub = {model .hub_content_name : model for model in js_models_in_hub }
278
+
279
+ models_to_sync = self ._determine_models_to_sync (model_version_list , mapped_models_in_hub )
290
280
JUMPSTART_LOGGER .warning (
291
- "Syncing the following models into Hub %s: %s" , self .hub_name , [ model . jumpstart_model_info for model in models_to_sync ]
281
+ "Syncing the following models into Hub %s: %s" , self .hub_name , models_to_sync
292
282
)
293
283
294
284
# Delete old models?
@@ -322,11 +312,11 @@ def sync(self, model_list: List[Dict[str, str]]):
322
312
f"Failures when importing models to curated hub in parallel: { failed_imports } "
323
313
)
324
314
325
- def _sync_public_model_to_hub (self , model : CuratedHubModelInfo , thread_num : int ):
315
+ def _sync_public_model_to_hub (self , model : JumpStartModelInfo , thread_num : int ):
326
316
"""Syncs a public JumpStart model version to the Hub. Runs in parallel."""
327
317
model_specs = utils .verify_model_region_and_return_specs (
328
- model_id = model .jumpstart_model_info . model_id ,
329
- version = model .jumpstart_model_info . version ,
318
+ model_id = model .model_id ,
319
+ version = model .version ,
330
320
region = self .region ,
331
321
scope = JumpStartScriptScope .INFERENCE ,
332
322
sagemaker_session = self ._sagemaker_session ,
@@ -335,7 +325,7 @@ def _sync_public_model_to_hub(self, model: CuratedHubModelInfo, thread_num: int)
335
325
336
326
dest_location = S3ObjectLocation (
337
327
bucket = self .hub_storage_location .bucket ,
338
- key = f"{ self .hub_storage_location .key } /curated_models/{ model .jumpstart_model_info . model_id } /{ model . jumpstart_model_info .version } " ,
328
+ key = f"{ self .hub_storage_location .key } /curated_models/{ model .model_id } /{ model .version } " ,
339
329
)
340
330
src_files = file_generator .generate_file_infos_from_model_specs (
341
331
model_specs , studio_specs , self .region , self ._s3_client
@@ -357,16 +347,16 @@ def _sync_public_model_to_hub(self, model: CuratedHubModelInfo, thread_num: int)
357
347
label = dest_location .key ,
358
348
).execute ()
359
349
else :
360
- JUMPSTART_LOGGER .warning ("Nothing to copy for %s v%s" , model .jumpstart_model_info . model_id , model . jumpstart_model_info .version )
350
+ JUMPSTART_LOGGER .warning ("Nothing to copy for %s v%s" , model .model_id , model .version )
361
351
362
352
# TODO: Tag model if specs say it is deprecated or training/inference
363
353
# vulnerable. Update tag of HubContent ARN without version.
364
354
# Versioned ARNs are not onboarded to Tagris.
365
355
tags = []
366
356
367
357
search_keywords = [
368
- f"{ JUMPSTART_HUB_MODEL_ID_TAG_PREFIX } :{ model .jumpstart_model_info . model_id } " ,
369
- f"{ JUMPSTART_HUB_MODEL_VERSION_TAG_PREFIX } :{ model .jumpstart_model_info . version } " ,
358
+ f"{ JUMPSTART_HUB_MODEL_ID_TAG_PREFIX } :{ model .model_id } " ,
359
+ f"{ JUMPSTART_HUB_MODEL_VERSION_TAG_PREFIX } :{ model .version } " ,
370
360
f"{ FRAMEWORK_TAG_PREFIX } :{ model_specs .get_framework ()} " ,
371
361
f"{ TASK_TAG_PREFIX } :TODO: pull from specs" ,
372
362
]
@@ -375,8 +365,8 @@ def _sync_public_model_to_hub(self, model: CuratedHubModelInfo, thread_num: int)
375
365
376
366
self ._sagemaker_session .import_hub_content (
377
367
document_schema_version = HubContentDocument_v2 .SCHEMA_VERSION ,
378
- hub_content_name = model .hub_content_model_id ,
379
- hub_content_version = model .hub_content_version ,
368
+ hub_content_name = model .model_id ,
369
+ hub_content_version = model .version ,
380
370
hub_name = self .hub_name ,
381
371
hub_content_document = hub_content_document ,
382
372
hub_content_type = HubContentType .MODEL ,
0 commit comments