13
13
"""This module provides the JumpStart Curated Hub class."""
14
14
from __future__ import absolute_import
15
15
from concurrent import futures
16
- from functools import lru_cache
16
+ from functools import lru_cache
17
17
from datetime import datetime
18
18
import json
19
19
import traceback
37
37
from sagemaker .jumpstart .curated_hub .sync .request import HubSyncRequestFactory
38
38
from sagemaker .jumpstart .enums import JumpStartScriptScope
39
39
from sagemaker .session import Session
40
- from sagemaker .jumpstart .constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION , JUMPSTART_LOGGER
40
+ from sagemaker .jumpstart .constants import (
41
+ DEFAULT_JUMPSTART_SAGEMAKER_SESSION ,
42
+ JUMPSTART_LOGGER ,
43
+ )
41
44
from sagemaker .jumpstart .types import (
42
45
DescribeHubResponse ,
43
46
DescribeHubContentsResponse ,
61
64
)
62
65
from sagemaker .utils import TagsDict
63
66
67
+
64
68
class CuratedHub :
65
69
"""Class for creating and managing a curated JumpStart hub"""
66
70
@@ -98,7 +102,9 @@ def _fetch_hub_bucket_name(self) -> str:
98
102
if hub_output_location :
99
103
location = create_s3_object_reference_from_uri (hub_output_location )
100
104
return location .bucket
101
- default_bucket_name = generate_default_hub_bucket_name (self ._sagemaker_session )
105
+ default_bucket_name = generate_default_hub_bucket_name (
106
+ self ._sagemaker_session
107
+ )
102
108
JUMPSTART_LOGGER .warning (
103
109
"There is not a Hub bucket associated with %s. Using %s" ,
104
110
self .hub_name ,
@@ -118,7 +124,9 @@ def _generate_hub_storage_location(self, bucket_name: Optional[str] = None) -> N
118
124
"""Generates an ``S3ObjectLocation`` given a Hub name."""
119
125
hub_bucket_name = bucket_name or self ._fetch_hub_bucket_name ()
120
126
curr_timestamp = datetime .now ().timestamp ()
121
- return S3ObjectLocation (bucket = hub_bucket_name , key = f"{ self .hub_name } -{ curr_timestamp } " )
127
+ return S3ObjectLocation (
128
+ bucket = hub_bucket_name , key = f"{ self .hub_name } -{ curr_timestamp } "
129
+ )
122
130
123
131
def create (
124
132
self ,
@@ -151,29 +159,30 @@ def describe(self) -> DescribeHubResponse:
151
159
152
160
return hub_description
153
161
154
-
155
162
def list_models (self , clear_cache : bool = True , ** kwargs ) -> List [Dict [str , Any ]]:
156
163
"""Lists the models in this Curated Hub.
157
164
158
- This function caches the models in local memory
165
+ This function caches the models in local memory
159
166
160
167
**kwargs: Passed to invocation of ``Session:list_hub_contents``.
161
168
"""
162
169
if clear_cache :
163
170
self ._list_hubs_cache = None
164
171
if self ._list_hubs_cache is None :
165
- hub_content_summaries = self ._sagemaker_session .list_hub_contents (
166
- hub_name = self .hub_name , hub_content_type = HubContentType .MODEL , ** kwargs
167
- )
168
- self ._list_hubs_cache = hub_content_summaries
172
+ hub_content_summaries = self ._sagemaker_session .list_hub_contents (
173
+ hub_name = self .hub_name , hub_content_type = HubContentType .MODEL , ** kwargs
174
+ )
175
+ self ._list_hubs_cache = hub_content_summaries
169
176
return self ._list_hubs_cache
170
177
171
178
def describe_model (
172
179
self , model_name : str , model_version : str = "*"
173
180
) -> DescribeHubContentsResponse :
174
181
"""Returns descriptive information about the Hub Model"""
175
182
176
- hub_content_description : Dict [str , Any ] = self ._sagemaker_session .describe_hub_content (
183
+ hub_content_description : Dict [
184
+ str , Any
185
+ ] = self ._sagemaker_session .describe_hub_content (
177
186
hub_name = self .hub_name ,
178
187
hub_content_name = model_name ,
179
188
hub_content_version = model_version ,
@@ -219,10 +228,16 @@ def _populate_latest_model_version(self, model: Dict[str, str]) -> Dict[str, str
219
228
220
229
def _get_jumpstart_models_in_hub (self ) -> List [HubContentSummary ]:
221
230
hub_models = summary_list_from_list_api_response (self .list_models ())
222
- return [model for model in hub_models if get_jumpstart_model_and_version (model ) is not None ]
231
+ return [
232
+ model
233
+ for model in hub_models
234
+ if get_jumpstart_model_and_version (model ) is not None
235
+ ]
223
236
224
237
def _determine_models_to_sync (
225
- self , model_list : List [JumpStartModelInfo ], models_in_hub : Dict [str , HubContentSummary ]
238
+ self ,
239
+ model_list : List [JumpStartModelInfo ],
240
+ models_in_hub : Dict [str , HubContentSummary ],
226
241
) -> List [JumpStartModelInfo ]:
227
242
"""Determines which models from `sync` params to sync into the CuratedHub.
228
243
@@ -286,14 +301,22 @@ def sync(self, model_list: List[Dict[str, str]]):
286
301
model ["model_id" ],
287
302
model ["version" ],
288
303
)
289
- model_version_list .append (JumpStartModelInfo (model ["model_id" ], model ["version" ]))
304
+ model_version_list .append (
305
+ JumpStartModelInfo (model ["model_id" ], model ["version" ])
306
+ )
290
307
291
308
js_models_in_hub = self ._get_jumpstart_models_in_hub ()
292
- mapped_models_in_hub = {model .hub_content_name : model for model in js_models_in_hub }
309
+ mapped_models_in_hub = {
310
+ model .hub_content_name : model for model in js_models_in_hub
311
+ }
293
312
294
- models_to_sync = self ._determine_models_to_sync (model_version_list , mapped_models_in_hub )
313
+ models_to_sync = self ._determine_models_to_sync (
314
+ model_version_list , mapped_models_in_hub
315
+ )
295
316
JUMPSTART_LOGGER .warning (
296
- "Syncing the following models into Hub %s: %s" , self .hub_name , models_to_sync
317
+ "Syncing the following models into Hub %s: %s" ,
318
+ self .hub_name ,
319
+ models_to_sync ,
297
320
)
298
321
299
322
# Delete old models?
@@ -305,7 +328,9 @@ def sync(self, model_list: List[Dict[str, str]]):
305
328
thread_name_prefix = "import-models-to-curated-hub" ,
306
329
) as import_executor :
307
330
for thread_num , model in enumerate (models_to_sync ):
308
- task = import_executor .submit (self ._sync_public_model_to_hub , model , thread_num )
331
+ task = import_executor .submit (
332
+ self ._sync_public_model_to_hub , model , thread_num
333
+ )
309
334
tasks .append (task )
310
335
311
336
# Handle failed imports
@@ -318,7 +343,9 @@ def sync(self, model_list: List[Dict[str, str]]):
318
343
{
319
344
"Exception" : exception ,
320
345
"Traceback" : "" .join (
321
- traceback .TracebackException .from_exception (exception ).format ()
346
+ traceback .TracebackException .from_exception (
347
+ exception
348
+ ).format ()
322
349
),
323
350
}
324
351
)
@@ -362,7 +389,9 @@ def _sync_public_model_to_hub(self, model: JumpStartModelInfo, thread_num: int):
362
389
label = dest_location .key ,
363
390
).execute ()
364
391
else :
365
- JUMPSTART_LOGGER .warning ("Nothing to copy for %s v%s" , model .model_id , model .version )
392
+ JUMPSTART_LOGGER .warning (
393
+ "Nothing to copy for %s v%s" , model .model_id , model .version
394
+ )
366
395
367
396
# TODO: Tag model if specs say it is deprecated or training/inference
368
397
# vulnerable. Update tag of HubContent ARN without version.
@@ -403,44 +432,47 @@ def _fetch_studio_specs(self, model_specs: JumpStartModelSpecs) -> Dict[str, Any
403
432
)
404
433
return json .loads (response ["Body" ].read ().decode ("utf-8" ))
405
434
406
-
407
435
def scan_and_tag_models (self , model_list : List [Dict [str , str ]] = None ) -> None :
408
436
"""Scans the Hub for JumpStart models and tags the HubContent.
409
-
437
+
410
438
If the scan detects a model is deprecated or vulnerable, it will tag the HubContent.
411
439
The tags that will be added are based off the specifications in the JumpStart public hub:
412
440
1. "deprecated_versions" -> If the public hub model is deprecated
413
441
2. "inference_vulnerable_versions" -> If the public hub model has inference vulnerabilities
414
442
3. "training_vulnerable_versions" -> If the public hub model has training vulnerabilities
415
443
416
- The tag value will be a list of versions in the Curated Hub that fall under those keys.
444
+ The tag value will be a list of versions in the Curated Hub that fall under those keys.
417
445
For example, if model_a version_a is deprecated and inference is vulnerable, the
418
- HubContent for `model_a` will have tags [{"deprecated_versions": [version_a]},
446
+ HubContent for `model_a` will have tags [{"deprecated_versions": [version_a]},
419
447
{"inference_vulnerable_versions": [version_a]}]
420
448
421
- If models are passed in,
449
+ If models are passed in,
422
450
"""
423
- JUMPSTART_LOGGER .info (
424
- "Tagging models in hub: %s" , self .hub_name
425
- )
451
+ JUMPSTART_LOGGER .info ("Tagging models in hub: %s" , self .hub_name )
426
452
if self ._is_invalid_model_list_input (model_list ):
427
453
raise ValueError (
428
454
"Model list should be a list of objects with values 'model_id'," ,
429
455
"and optional 'version'." ,
430
456
)
431
-
457
+
432
458
models_to_scan = model_list if model_list else self .list_models ()
433
- js_models_in_hub = [model for model in models_to_scan if get_jumpstart_model_and_version (model ) is not None ]
459
+ js_models_in_hub = [
460
+ model
461
+ for model in models_to_scan
462
+ if get_jumpstart_model_and_version (model ) is not None
463
+ ]
434
464
for model in js_models_in_hub :
435
- tags_to_add : List [TagsDict ] = find_unsupported_flags_for_hub_content_versions (
465
+ tags_to_add : List [
466
+ TagsDict
467
+ ] = find_unsupported_flags_for_hub_content_versions (
436
468
hub_name = self .hub_name ,
437
469
hub_content_name = model .hub_content_name ,
438
470
region = self .region ,
439
- session = self ._sagemaker_session
471
+ session = self ._sagemaker_session ,
440
472
)
441
473
tag_hub_content (
442
474
hub_content_arn = model .hub_content_arn ,
443
475
tags = tags_to_add ,
444
- session = self ._sagemaker_session
476
+ session = self ._sagemaker_session ,
445
477
)
446
- JUMPSTART_LOGGER .info ("Tagging complete!" )
478
+ JUMPSTART_LOGGER .info ("Tagging complete!" )
0 commit comments