12
12
# language governing permissions and limitations under the License.
13
13
"""This module accessors for the SageMaker JumpStart Public Hub."""
14
14
from __future__ import absolute_import
15
- from typing import Dict , Any
15
+ from typing import Dict , Any , Optional
16
16
from sagemaker import model_uris , script_uris
17
17
from sagemaker .jumpstart .curated_hub .types import (
18
18
HubContentDependencyType ,
21
21
from sagemaker .jumpstart .curated_hub .utils import create_s3_object_reference_from_uri
22
22
from sagemaker .jumpstart .enums import JumpStartScriptScope
23
23
from sagemaker .jumpstart .types import JumpStartModelSpecs
24
- from sagemaker .jumpstart .utils import get_jumpstart_content_bucket
24
+ from sagemaker .jumpstart .utils import (
25
+ get_jumpstart_content_bucket ,
26
+ get_jumpstart_gated_content_bucket ,
27
+ )
25
28
26
29
27
30
class PublicModelDataAccessor :
@@ -35,7 +38,11 @@ def __init__(
35
38
):
36
39
"""Creates a PublicModelDataAccessor."""
37
40
self ._region = region
38
- self ._bucket = get_jumpstart_content_bucket (region )
41
+ self ._bucket = (
42
+ get_jumpstart_gated_content_bucket (region )
43
+ if model_specs .gated_bucket
44
+ else get_jumpstart_content_bucket (region )
45
+ )
39
46
self .model_specs = model_specs
40
47
self .studio_specs = studio_specs # Necessary for SDK - Studio metadata drift
41
48
@@ -44,47 +51,53 @@ def get_s3_reference(self, dependency_type: HubContentDependencyType):
44
51
return getattr (self , dependency_type .value )
45
52
46
53
@property
47
- def inference_artifact_s3_reference (self ):
54
+ def inference_artifact_s3_reference (self ) -> Optional [ S3ObjectLocation ] :
48
55
"""Retrieves s3 reference for model inference artifact"""
49
56
return create_s3_object_reference_from_uri (
50
57
self ._jumpstart_artifact_s3_uri (JumpStartScriptScope .INFERENCE )
51
58
)
52
59
53
60
@property
54
- def training_artifact_s3_reference (self ):
61
+ def training_artifact_s3_reference (self ) -> Optional [ S3ObjectLocation ] :
55
62
"""Retrieves s3 reference for model training artifact"""
63
+ if not self .model_specs .training_supported :
64
+ return None
56
65
return create_s3_object_reference_from_uri (
57
66
self ._jumpstart_artifact_s3_uri (JumpStartScriptScope .TRAINING )
58
67
)
59
68
60
69
@property
61
- def inference_script_s3_reference (self ):
70
+ def inference_script_s3_reference (self ) -> Optional [ S3ObjectLocation ] :
62
71
"""Retrieves s3 reference for model inference script"""
63
72
return create_s3_object_reference_from_uri (
64
73
self ._jumpstart_script_s3_uri (JumpStartScriptScope .INFERENCE )
65
74
)
66
75
67
76
@property
68
- def training_script_s3_reference (self ):
77
+ def training_script_s3_reference (self ) -> Optional [ S3ObjectLocation ] :
69
78
"""Retrieves s3 reference for model training script"""
79
+ if not self .model_specs .training_supported :
80
+ return None
70
81
return create_s3_object_reference_from_uri (
71
82
self ._jumpstart_script_s3_uri (JumpStartScriptScope .TRAINING )
72
83
)
73
84
74
85
@property
75
- def default_training_dataset_s3_reference (self ):
86
+ def default_training_dataset_s3_reference (self ) -> S3ObjectLocation :
76
87
"""Retrieves s3 reference for s3 directory containing model training datasets"""
77
- return S3ObjectLocation (self ._get_bucket_name (), self .__get_training_dataset_prefix ())
88
+ if not self .model_specs .training_supported :
89
+ return None
90
+ return S3ObjectLocation (self ._get_bucket_name (), self ._get_training_dataset_prefix ())
78
91
79
92
@property
80
- def demo_notebook_s3_reference (self ):
93
+ def demo_notebook_s3_reference (self ) -> S3ObjectLocation :
81
94
"""Retrieves s3 reference for model demo jupyter notebook"""
82
95
framework = self .model_specs .get_framework ()
83
96
key = f"{ framework } -notebooks/{ self .model_specs .model_id } -inference.ipynb"
84
97
return S3ObjectLocation (self ._get_bucket_name (), key )
85
98
86
99
@property
87
- def markdown_s3_reference (self ):
100
+ def markdown_s3_reference (self ) -> S3ObjectLocation :
88
101
"""Retrieves s3 reference for model markdown"""
89
102
framework = self .model_specs .get_framework ()
90
103
key = f"{ framework } -metadata/{ self .model_specs .model_id } .md"
@@ -94,24 +107,30 @@ def _get_bucket_name(self) -> str:
94
107
"""Retrieves s3 bucket"""
95
108
return self ._bucket
96
109
97
- def __get_training_dataset_prefix (self ) -> str :
110
+ def _get_training_dataset_prefix (self ) -> Optional [ str ] :
98
111
"""Retrieves training dataset location"""
99
- return self .studio_specs [ "defaultDataKey" ]
112
+ return self .studio_specs . get ( "defaultDataKey" )
100
113
101
- def _jumpstart_script_s3_uri (self , model_scope : str ) -> str :
114
+ def _jumpstart_script_s3_uri (self , model_scope : str ) -> Optional [ str ] :
102
115
"""Retrieves JumpStart script s3 location"""
103
- return script_uris .retrieve (
104
- region = self ._region ,
105
- model_id = self .model_specs .model_id ,
106
- model_version = self .model_specs .version ,
107
- script_scope = model_scope ,
108
- )
116
+ try :
117
+ return script_uris .retrieve (
118
+ region = self ._region ,
119
+ model_id = self .model_specs .model_id ,
120
+ model_version = self .model_specs .version ,
121
+ script_scope = model_scope ,
122
+ )
123
+ except ValueError :
124
+ return None
109
125
110
- def _jumpstart_artifact_s3_uri (self , model_scope : str ) -> str :
126
+ def _jumpstart_artifact_s3_uri (self , model_scope : str ) -> Optional [ str ] :
111
127
"""Retrieves JumpStart artifact s3 location"""
112
- return model_uris .retrieve (
113
- region = self ._region ,
114
- model_id = self .model_specs .model_id ,
115
- model_version = self .model_specs .version ,
116
- model_scope = model_scope ,
117
- )
128
+ try :
129
+ return model_uris .retrieve (
130
+ region = self ._region ,
131
+ model_id = self .model_specs .model_id ,
132
+ model_version = self .model_specs .version ,
133
+ model_scope = model_scope ,
134
+ )
135
+ except ValueError :
136
+ return None
0 commit comments