34
34
DEFAULT_JUMPSTART_SAGEMAKER_SESSION ,
35
35
MODEL_TYPE_TO_MANIFEST_MAP ,
36
36
MODEL_TYPE_TO_SPECS_MAP ,
37
- DEFAULT_JUMPSTART_SAGEMAKER_SESSION ,
38
37
)
39
38
from sagemaker .jumpstart .exceptions import (
40
39
get_wildcard_model_version_msg ,
@@ -178,9 +177,7 @@ def set_manifest_file_s3_key(
178
177
}
179
178
property_name = file_mapping .get (file_type )
180
179
if not property_name :
181
- raise ValueError (
182
- self ._file_type_error_msg (file_type , manifest_only = True )
183
- )
180
+ raise ValueError (self ._file_type_error_msg (file_type , manifest_only = True ))
184
181
if key != property_name :
185
182
setattr (self , property_name , key )
186
183
self .clear ()
@@ -193,9 +190,7 @@ def get_manifest_file_s3_key(
193
190
return self ._manifest_file_s3_key
194
191
if file_type == JumpStartS3FileType .PROPRIETARY_MANIFEST :
195
192
return self ._proprietary_manifest_s3_key
196
- raise ValueError (
197
- self ._file_type_error_msg (file_type , manifest_only = True )
198
- )
193
+ raise ValueError (self ._file_type_error_msg (file_type , manifest_only = True ))
199
194
200
195
def set_s3_bucket_name (self , s3_bucket_name : str ) -> None :
201
196
"""Set s3 bucket used for cache."""
@@ -248,7 +243,8 @@ def _model_id_retrieval_function(
248
243
sm_version = utils .get_sagemaker_version ()
249
244
manifest = self ._content_cache .get (
250
245
JumpStartCachedContentKey (
251
- MODEL_TYPE_TO_MANIFEST_MAP [model_type ], self ._manifest_file_s3_map [model_type ])
246
+ MODEL_TYPE_TO_MANIFEST_MAP [model_type ], self ._manifest_file_s3_map [model_type ]
247
+ )
252
248
)[0 ].formatted_content
253
249
254
250
versions_compatible_with_sagemaker = [
@@ -265,7 +261,8 @@ def _model_id_retrieval_function(
265
261
return JumpStartVersionedModelId (model_id , sm_compatible_model_version )
266
262
267
263
versions_incompatible_with_sagemaker = [
268
- Version (header .version ) for header in manifest .values () # type: ignore
264
+ Version (header .version )
265
+ for header in manifest .values () # type: ignore
269
266
if header .model_id == model_id
270
267
]
271
268
sm_incompatible_model_version = self ._select_version (
@@ -295,9 +292,7 @@ def _model_id_retrieval_function(
295
292
raise KeyError (error_msg )
296
293
297
294
error_msg = f"Unable to find model manifest for '{ model_id } ' with version '{ version } '. "
298
- error_msg += (
299
- f"Visit { MODEL_ID_LIST_WEB_URL } for updated list of models. "
300
- )
295
+ error_msg += f"Visit { MODEL_ID_LIST_WEB_URL } for updated list of models. "
301
296
302
297
other_model_id_version = None
303
298
if model_type == JumpStartModelType .OPEN_WEIGHTS :
@@ -306,19 +301,17 @@ def _model_id_retrieval_function(
306
301
) # all versions here are incompatible with sagemaker
307
302
elif model_type == JumpStartModelType .PROPRIETARY :
308
303
all_possible_model_id_version = [
309
- header .version for header in manifest .values () # type: ignore
304
+ header .version
305
+ for header in manifest .values () # type: ignore
310
306
if header .model_id == model_id
311
307
]
312
308
other_model_id_version = (
313
- None
314
- if not all_possible_model_id_version
315
- else all_possible_model_id_version [0 ]
309
+ None if not all_possible_model_id_version else all_possible_model_id_version [0 ]
316
310
)
317
311
318
312
if other_model_id_version is not None :
319
313
error_msg += (
320
- f"Consider using model ID '{ model_id } ' with version "
321
- f"'{ other_model_id_version } '."
314
+ f"Consider using model ID '{ model_id } ' with version " f"'{ other_model_id_version } '."
322
315
)
323
316
else :
324
317
possible_model_ids = [header .model_id for header in manifest .values ()] # type: ignore
@@ -360,15 +353,15 @@ def _get_json_file_and_etag_from_s3(self, key: str) -> Tuple[Union[dict, list],
360
353
361
354
def _is_local_metadata_mode (self ) -> bool :
362
355
"""Returns True if the cache should use local metadata mode, based off env variables."""
363
- return (ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE in os .environ
364
- and os .path .isdir (os .environ [ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE ])
365
- and ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE in os .environ
366
- and os .path .isdir (os .environ [ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE ]))
356
+ return (
357
+ ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE in os .environ
358
+ and os .path .isdir (os .environ [ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE ])
359
+ and ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE in os .environ
360
+ and os .path .isdir (os .environ [ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE ])
361
+ )
367
362
368
363
def _get_json_file (
369
- self ,
370
- key : str ,
371
- filetype : JumpStartS3FileType
364
+ self , key : str , filetype : JumpStartS3FileType
372
365
) -> Tuple [Union [dict , list ], Optional [str ]]:
373
366
"""Returns json file either from s3 or local file system.
374
367
@@ -392,21 +385,19 @@ def _get_json_md5_hash(self, key: str):
392
385
return self ._s3_client .head_object (Bucket = self .s3_bucket_name , Key = key )["ETag" ]
393
386
394
387
def _get_json_file_from_local_override (
395
- self ,
396
- key : str ,
397
- filetype : JumpStartS3FileType
388
+ self , key : str , filetype : JumpStartS3FileType
398
389
) -> Union [dict , list ]:
399
390
"""Reads json file from local filesystem and returns data."""
400
391
if filetype == JumpStartS3FileType .OPEN_WEIGHT_MANIFEST :
401
- metadata_local_root = (
402
- os . environ [ ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE ]
403
- )
392
+ metadata_local_root = os . environ [
393
+ ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE
394
+ ]
404
395
elif filetype == JumpStartS3FileType .OPEN_WEIGHT_SPECS :
405
396
metadata_local_root = os .environ [ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE ]
406
397
else :
407
398
raise ValueError (f"Unsupported file type for local override: { filetype } " )
408
399
file_path = os .path .join (metadata_local_root , key )
409
- with open (file_path , 'r' ) as f :
400
+ with open (file_path , "r" ) as f :
410
401
data = json .load (f )
411
402
return data
412
403
@@ -443,17 +434,15 @@ def _retrieval_function(
443
434
formatted_content = utils .get_formatted_manifest (formatted_body ),
444
435
md5_hash = etag ,
445
436
)
446
-
437
+
447
438
if data_type in {
448
439
JumpStartS3FileType .OPEN_WEIGHT_SPECS ,
449
440
JumpStartS3FileType .PROPRIETARY_SPECS ,
450
441
}:
451
442
formatted_body , _ = self ._get_json_file (id_info , data_type )
452
443
model_specs = JumpStartModelSpecs (formatted_body )
453
444
utils .emit_logs_based_on_model_specs (model_specs , self .get_region (), self ._s3_client )
454
- return JumpStartCachedContentValue (
455
- formatted_content = model_specs
456
- )
445
+ return JumpStartCachedContentValue (formatted_content = model_specs )
457
446
458
447
if data_type == HubContentType .MODEL :
459
448
hub_name , _ , model_name , model_version = hub_utils .get_info_from_hub_resource_arn (
@@ -463,21 +452,15 @@ def _retrieval_function(
463
452
hub_name = hub_name ,
464
453
hub_content_name = model_name ,
465
454
hub_content_version = model_version ,
466
- hub_content_type = data_type
455
+ hub_content_type = data_type ,
467
456
)
468
457
469
458
model_specs = JumpStartModelSpecs (
470
459
DescribeHubContentsResponse (hub_model_description ), is_hub_content = True
471
460
)
472
461
473
- utils .emit_logs_based_on_model_specs (
474
- model_specs ,
475
- self .get_region (),
476
- self ._s3_client
477
- )
478
- return JumpStartCachedContentValue (
479
- formatted_content = model_specs
480
- )
462
+ utils .emit_logs_based_on_model_specs (model_specs , self .get_region (), self ._s3_client )
463
+ return JumpStartCachedContentValue (formatted_content = model_specs )
481
464
482
465
if data_type == HubType .HUB :
483
466
hub_name , _ , _ , _ = hub_utils .get_info_from_hub_resource_arn (id_info )
@@ -487,9 +470,7 @@ def _retrieval_function(
487
470
formatted_content = DescribeHubResponse (hub_description )
488
471
)
489
472
490
- raise ValueError (
491
- self ._file_type_error_msg (data_type )
492
- )
473
+ raise ValueError (self ._file_type_error_msg (data_type ))
493
474
494
475
def get_manifest (
495
476
self ,
@@ -498,7 +479,8 @@ def get_manifest(
498
479
"""Return entire JumpStart models manifest."""
499
480
manifest_dict = self ._content_cache .get (
500
481
JumpStartCachedContentKey (
501
- MODEL_TYPE_TO_MANIFEST_MAP [model_type ], self ._manifest_file_s3_map [model_type ])
482
+ MODEL_TYPE_TO_MANIFEST_MAP [model_type ], self ._manifest_file_s3_map [model_type ]
483
+ )
502
484
)[0 ].formatted_content
503
485
manifest = list (manifest_dict .values ()) # type: ignore
504
486
return manifest
@@ -555,16 +537,14 @@ def _select_version(
555
537
except InvalidSpecifier :
556
538
raise KeyError (f"Bad semantic version: { version_str } " )
557
539
available_versions_filtered = list (spec .filter (available_versions ))
558
- return (
559
- str (max (available_versions_filtered )) if available_versions_filtered != [] else None
560
- )
540
+ return str (max (available_versions_filtered )) if available_versions_filtered != [] else None
561
541
562
542
def _get_header_impl (
563
543
self ,
564
544
model_id : str ,
565
545
semantic_version_str : str ,
566
546
attempt : int = 0 ,
567
- model_type : JumpStartModelType = JumpStartModelType .OPEN_WEIGHTS
547
+ model_type : JumpStartModelType = JumpStartModelType .OPEN_WEIGHTS ,
568
548
) -> JumpStartModelHeader :
569
549
"""Lower-level function to return header.
570
550
@@ -587,7 +567,8 @@ def _get_header_impl(
587
567
588
568
manifest = self ._content_cache .get (
589
569
JumpStartCachedContentKey (
590
- MODEL_TYPE_TO_MANIFEST_MAP [model_type ], self ._manifest_file_s3_map [model_type ])
570
+ MODEL_TYPE_TO_MANIFEST_MAP [model_type ], self ._manifest_file_s3_map [model_type ]
571
+ )
591
572
)[0 ].formatted_content
592
573
593
574
try :
@@ -603,7 +584,7 @@ def get_specs(
603
584
self ,
604
585
model_id : str ,
605
586
version_str : str ,
606
- model_type : JumpStartModelType = JumpStartModelType .OPEN_WEIGHTS
587
+ model_type : JumpStartModelType = JumpStartModelType .OPEN_WEIGHTS ,
607
588
) -> JumpStartModelSpecs :
608
589
"""Return specs for a given JumpStart model ID and semantic version.
609
590
@@ -616,16 +597,12 @@ def get_specs(
616
597
header = self .get_header (model_id , version_str , model_type )
617
598
spec_key = header .spec_key
618
599
specs , cache_hit = self ._content_cache .get (
619
- JumpStartCachedContentKey (
620
- MODEL_TYPE_TO_SPECS_MAP [model_type ], spec_key
621
- )
600
+ JumpStartCachedContentKey (MODEL_TYPE_TO_SPECS_MAP [model_type ], spec_key )
622
601
)
623
602
624
603
if not cache_hit and "*" in version_str :
625
604
JUMPSTART_LOGGER .warning (
626
- get_wildcard_model_version_msg (
627
- header .model_id , version_str , header .version
628
- )
605
+ get_wildcard_model_version_msg (header .model_id , version_str , header .version )
629
606
)
630
607
return specs .formatted_content
631
608
0 commit comments