13
13
"""This module provides the JumpStart Curated Hub class."""
14
14
from __future__ import absolute_import
15
15
from concurrent import futures
16
- from functools import lru_cache
17
16
from datetime import datetime
18
17
import json
19
18
import traceback
20
19
from typing import Optional , Dict , List , Any
21
-
22
20
import boto3
23
21
from botocore import exceptions
24
22
from botocore .client import BaseClient
@@ -102,9 +100,7 @@ def _fetch_hub_bucket_name(self) -> str:
102
100
if hub_output_location :
103
101
location = create_s3_object_reference_from_uri (hub_output_location )
104
102
return location .bucket
105
- default_bucket_name = generate_default_hub_bucket_name (
106
- self ._sagemaker_session
107
- )
103
+ default_bucket_name = generate_default_hub_bucket_name (self ._sagemaker_session )
108
104
JUMPSTART_LOGGER .warning (
109
105
"There is not a Hub bucket associated with %s. Using %s" ,
110
106
self .hub_name ,
@@ -124,9 +120,7 @@ def _generate_hub_storage_location(self, bucket_name: Optional[str] = None) -> N
124
120
"""Generates an ``S3ObjectLocation`` given a Hub name."""
125
121
hub_bucket_name = bucket_name or self ._fetch_hub_bucket_name ()
126
122
curr_timestamp = datetime .now ().timestamp ()
127
- return S3ObjectLocation (
128
- bucket = hub_bucket_name , key = f"{ self .hub_name } -{ curr_timestamp } "
129
- )
123
+ return S3ObjectLocation (bucket = hub_bucket_name , key = f"{ self .hub_name } -{ curr_timestamp } " )
130
124
131
125
def create (
132
126
self ,
@@ -180,9 +174,7 @@ def describe_model(
180
174
) -> DescribeHubContentsResponse :
181
175
"""Returns descriptive information about the Hub Model"""
182
176
183
- hub_content_description : Dict [
184
- str , Any
185
- ] = self ._sagemaker_session .describe_hub_content (
177
+ hub_content_description : Dict [str , Any ] = self ._sagemaker_session .describe_hub_content (
186
178
hub_name = self .hub_name ,
187
179
hub_content_name = model_name ,
188
180
hub_content_version = model_version ,
@@ -228,11 +220,7 @@ def _populate_latest_model_version(self, model: Dict[str, str]) -> Dict[str, str
228
220
229
221
def _get_jumpstart_models_in_hub (self ) -> List [HubContentSummary ]:
230
222
hub_models = summary_list_from_list_api_response (self .list_models ())
231
- return [
232
- model
233
- for model in hub_models
234
- if get_jumpstart_model_and_version (model ) is not None
235
- ]
223
+ return [model for model in hub_models if get_jumpstart_model_and_version (model ) is not None ]
236
224
237
225
def _determine_models_to_sync (
238
226
self ,
@@ -301,18 +289,12 @@ def sync(self, model_list: List[Dict[str, str]]):
301
289
model ["model_id" ],
302
290
model ["version" ],
303
291
)
304
- model_version_list .append (
305
- JumpStartModelInfo (model ["model_id" ], model ["version" ])
306
- )
292
+ model_version_list .append (JumpStartModelInfo (model ["model_id" ], model ["version" ]))
307
293
308
294
js_models_in_hub = self ._get_jumpstart_models_in_hub ()
309
- mapped_models_in_hub = {
310
- model .hub_content_name : model for model in js_models_in_hub
311
- }
295
+ mapped_models_in_hub = {model .hub_content_name : model for model in js_models_in_hub }
312
296
313
- models_to_sync = self ._determine_models_to_sync (
314
- model_version_list , mapped_models_in_hub
315
- )
297
+ models_to_sync = self ._determine_models_to_sync (model_version_list , mapped_models_in_hub )
316
298
JUMPSTART_LOGGER .warning (
317
299
"Syncing the following models into Hub %s: %s" ,
318
300
self .hub_name ,
@@ -328,9 +310,7 @@ def sync(self, model_list: List[Dict[str, str]]):
328
310
thread_name_prefix = "import-models-to-curated-hub" ,
329
311
) as import_executor :
330
312
for thread_num , model in enumerate (models_to_sync ):
331
- task = import_executor .submit (
332
- self ._sync_public_model_to_hub , model , thread_num
333
- )
313
+ task = import_executor .submit (self ._sync_public_model_to_hub , model , thread_num )
334
314
tasks .append (task )
335
315
336
316
# Handle failed imports
@@ -343,9 +323,7 @@ def sync(self, model_list: List[Dict[str, str]]):
343
323
{
344
324
"Exception" : exception ,
345
325
"Traceback" : "" .join (
346
- traceback .TracebackException .from_exception (
347
- exception
348
- ).format ()
326
+ traceback .TracebackException .from_exception (exception ).format ()
349
327
),
350
328
}
351
329
)
@@ -389,9 +367,7 @@ def _sync_public_model_to_hub(self, model: JumpStartModelInfo, thread_num: int):
389
367
label = dest_location .key ,
390
368
).execute ()
391
369
else :
392
- JUMPSTART_LOGGER .warning (
393
- "Nothing to copy for %s v%s" , model .model_id , model .version
394
- )
370
+ JUMPSTART_LOGGER .warning ("Nothing to copy for %s v%s" , model .model_id , model .version )
395
371
396
372
# TODO: Tag model if specs say it is deprecated or training/inference
397
373
# vulnerable. Update tag of HubContent ARN without version.
@@ -457,14 +433,10 @@ def scan_and_tag_models(self, model_list: List[Dict[str, str]] = None) -> None:
457
433
458
434
models_to_scan = model_list if model_list else self .list_models ()
459
435
js_models_in_hub = [
460
- model
461
- for model in models_to_scan
462
- if get_jumpstart_model_and_version (model ) is not None
436
+ model for model in models_to_scan if get_jumpstart_model_and_version (model ) is not None
463
437
]
464
438
for model in js_models_in_hub :
465
- tags_to_add : List [
466
- TagsDict
467
- ] = find_unsupported_flags_for_hub_content_versions (
439
+ tags_to_add : List [TagsDict ] = find_unsupported_flags_for_hub_content_versions (
468
440
hub_name = self .hub_name ,
469
441
hub_content_name = model .hub_content_name ,
470
442
region = self .region ,
0 commit comments