Skip to content

Commit 054a8cf

Browse files
authored
Merge branch 'master-jumpstart-curated-hub' into curated_hub_tagris_copy
2 parents 29a0740 + a632795 commit 054a8cf

File tree

7 files changed

+110
-49
lines changed

7 files changed

+110
-49
lines changed

src/sagemaker/jumpstart/curated_hub/accessors/file_generator.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from sagemaker.jumpstart.curated_hub.accessors.public_model_data import (
2525
PublicModelDataAccessor,
2626
)
27+
from sagemaker.jumpstart.curated_hub.utils import is_gated_bucket
2728
from sagemaker.jumpstart.types import JumpStartModelSpecs
2829

2930

@@ -67,6 +68,10 @@ def generate_file_infos_from_model_specs(
6768
files = []
6869
for dependency in HubContentDependencyType:
6970
location: S3ObjectLocation = public_model_data_accessor.get_s3_reference(dependency)
71+
# Training dependencies will return as None if training is unsupported
72+
if not location or is_gated_bucket(location.bucket):
73+
continue
74+
7075
location_type = "prefix" if location.key.endswith("/") else "object"
7176

7277
if location_type == "prefix":

src/sagemaker/jumpstart/curated_hub/accessors/public_model_data.py

Lines changed: 46 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# language governing permissions and limitations under the License.
1313
"""This module accessors for the SageMaker JumpStart Public Hub."""
1414
from __future__ import absolute_import
15-
from typing import Dict, Any
15+
from typing import Dict, Any, Optional
1616
from sagemaker import model_uris, script_uris
1717
from sagemaker.jumpstart.curated_hub.types import (
1818
HubContentDependencyType,
@@ -21,7 +21,10 @@
2121
from sagemaker.jumpstart.curated_hub.utils import create_s3_object_reference_from_uri
2222
from sagemaker.jumpstart.enums import JumpStartScriptScope
2323
from sagemaker.jumpstart.types import JumpStartModelSpecs
24-
from sagemaker.jumpstart.utils import get_jumpstart_content_bucket
24+
from sagemaker.jumpstart.utils import (
25+
get_jumpstart_content_bucket,
26+
get_jumpstart_gated_content_bucket,
27+
)
2528

2629

2730
class PublicModelDataAccessor:
@@ -35,7 +38,11 @@ def __init__(
3538
):
3639
"""Creates a PublicModelDataAccessor."""
3740
self._region = region
38-
self._bucket = get_jumpstart_content_bucket(region)
41+
self._bucket = (
42+
get_jumpstart_gated_content_bucket(region)
43+
if model_specs.gated_bucket
44+
else get_jumpstart_content_bucket(region)
45+
)
3946
self.model_specs = model_specs
4047
self.studio_specs = studio_specs # Necessary for SDK - Studio metadata drift
4148

@@ -44,47 +51,53 @@ def get_s3_reference(self, dependency_type: HubContentDependencyType):
4451
return getattr(self, dependency_type.value)
4552

4653
@property
47-
def inference_artifact_s3_reference(self):
54+
def inference_artifact_s3_reference(self) -> Optional[S3ObjectLocation]:
4855
"""Retrieves s3 reference for model inference artifact"""
4956
return create_s3_object_reference_from_uri(
5057
self._jumpstart_artifact_s3_uri(JumpStartScriptScope.INFERENCE)
5158
)
5259

5360
@property
54-
def training_artifact_s3_reference(self):
61+
def training_artifact_s3_reference(self) -> Optional[S3ObjectLocation]:
5562
"""Retrieves s3 reference for model training artifact"""
63+
if not self.model_specs.training_supported:
64+
return None
5665
return create_s3_object_reference_from_uri(
5766
self._jumpstart_artifact_s3_uri(JumpStartScriptScope.TRAINING)
5867
)
5968

6069
@property
61-
def inference_script_s3_reference(self):
70+
def inference_script_s3_reference(self) -> Optional[S3ObjectLocation]:
6271
"""Retrieves s3 reference for model inference script"""
6372
return create_s3_object_reference_from_uri(
6473
self._jumpstart_script_s3_uri(JumpStartScriptScope.INFERENCE)
6574
)
6675

6776
@property
68-
def training_script_s3_reference(self):
77+
def training_script_s3_reference(self) -> Optional[S3ObjectLocation]:
6978
"""Retrieves s3 reference for model training script"""
79+
if not self.model_specs.training_supported:
80+
return None
7081
return create_s3_object_reference_from_uri(
7182
self._jumpstart_script_s3_uri(JumpStartScriptScope.TRAINING)
7283
)
7384

7485
@property
75-
def default_training_dataset_s3_reference(self):
86+
def default_training_dataset_s3_reference(self) -> S3ObjectLocation:
7687
"""Retrieves s3 reference for s3 directory containing model training datasets"""
77-
return S3ObjectLocation(self._get_bucket_name(), self.__get_training_dataset_prefix())
88+
if not self.model_specs.training_supported:
89+
return None
90+
return S3ObjectLocation(self._get_bucket_name(), self._get_training_dataset_prefix())
7891

7992
@property
80-
def demo_notebook_s3_reference(self):
93+
def demo_notebook_s3_reference(self) -> S3ObjectLocation:
8194
"""Retrieves s3 reference for model demo jupyter notebook"""
8295
framework = self.model_specs.get_framework()
8396
key = f"{framework}-notebooks/{self.model_specs.model_id}-inference.ipynb"
8497
return S3ObjectLocation(self._get_bucket_name(), key)
8598

8699
@property
87-
def markdown_s3_reference(self):
100+
def markdown_s3_reference(self) -> S3ObjectLocation:
88101
"""Retrieves s3 reference for model markdown"""
89102
framework = self.model_specs.get_framework()
90103
key = f"{framework}-metadata/{self.model_specs.model_id}.md"
@@ -94,24 +107,30 @@ def _get_bucket_name(self) -> str:
94107
"""Retrieves s3 bucket"""
95108
return self._bucket
96109

97-
def __get_training_dataset_prefix(self) -> str:
110+
def _get_training_dataset_prefix(self) -> Optional[str]:
98111
"""Retrieves training dataset location"""
99-
return self.studio_specs["defaultDataKey"]
112+
return self.studio_specs.get("defaultDataKey")
100113

101-
def _jumpstart_script_s3_uri(self, model_scope: str) -> str:
114+
def _jumpstart_script_s3_uri(self, model_scope: str) -> Optional[str]:
102115
"""Retrieves JumpStart script s3 location"""
103-
return script_uris.retrieve(
104-
region=self._region,
105-
model_id=self.model_specs.model_id,
106-
model_version=self.model_specs.version,
107-
script_scope=model_scope,
108-
)
116+
try:
117+
return script_uris.retrieve(
118+
region=self._region,
119+
model_id=self.model_specs.model_id,
120+
model_version=self.model_specs.version,
121+
script_scope=model_scope,
122+
)
123+
except ValueError:
124+
return None
109125

110-
def _jumpstart_artifact_s3_uri(self, model_scope: str) -> str:
126+
def _jumpstart_artifact_s3_uri(self, model_scope: str) -> Optional[str]:
111127
"""Retrieves JumpStart artifact s3 location"""
112-
return model_uris.retrieve(
113-
region=self._region,
114-
model_id=self.model_specs.model_id,
115-
model_version=self.model_specs.version,
116-
model_scope=model_scope,
117-
)
128+
try:
129+
return model_uris.retrieve(
130+
region=self._region,
131+
model_id=self.model_specs.model_id,
132+
model_version=self.model_specs.version,
133+
model_scope=model_scope,
134+
)
135+
except ValueError:
136+
return None

src/sagemaker/jumpstart/curated_hub/utils.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,8 +143,11 @@ def generate_default_hub_bucket_name(
143143
return f"sagemaker-hubs-{region}-{account_id}"
144144

145145

146-
def create_s3_object_reference_from_uri(s3_uri: str) -> S3ObjectLocation:
146+
def create_s3_object_reference_from_uri(s3_uri: Optional[str]) -> Optional[S3ObjectLocation]:
147147
"""Utiity to help generate an S3 object reference"""
148+
if not s3_uri:
149+
return None
150+
148151
bucket, key = parse_s3_url(s3_uri)
149152

150153
return S3ObjectLocation(
@@ -303,3 +306,8 @@ def get_jumpstart_model_and_version(
303306
len(JUMPSTART_HUB_MODEL_VERSION_TAG_PREFIX) :
304307
]
305308
return JumpStartModelInfo(model_id=jumpstart_model_id, version=jumpstart_model_version)
309+
310+
311+
def is_gated_bucket(bucket_name: str) -> bool:
312+
"""Returns true if the bucket name is the JumpStart gated bucket."""
313+
return bucket_name in constants.JUMPSTART_GATED_BUCKET_NAME_SET

tests/unit/sagemaker/jumpstart/curated_hub/test_filegenerator.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,3 +127,46 @@ def test_s3_path_file_generator_with_no_objects(s3_client):
127127

128128
s3_client.list_objects_v2.assert_called_once()
129129
assert response == []
130+
131+
132+
@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
133+
def test_specs_file_generator_training_unsupported(patched_get_model_specs, s3_client):
134+
specs = Mock()
135+
specs.model_id = "mock_model_123"
136+
specs.training_supported = False
137+
specs.gated_bucket = False
138+
specs.hosting_prepacked_artifact_key = "/my/inference/tarball.tgz"
139+
specs.hosting_script_key = "/my/inference/script.py"
140+
141+
response = generate_file_infos_from_model_specs(specs, {}, "us-west-2", s3_client)
142+
143+
assert response == [
144+
FileInfo(
145+
"jumpstart-cache-prod-us-west-2",
146+
"/my/inference/tarball.tgz",
147+
123456789,
148+
"08-14-1997 00:00:00",
149+
),
150+
FileInfo(
151+
"jumpstart-cache-prod-us-west-2",
152+
"/my/inference/script.py",
153+
123456789,
154+
"08-14-1997 00:00:00",
155+
),
156+
]
157+
158+
159+
@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
160+
def test_specs_file_generator_gated_model(patched_get_model_specs, s3_client):
161+
specs = Mock()
162+
specs.model_id = "mock_model_123"
163+
specs.gated_bucket = True
164+
specs.training_supported = True
165+
specs.hosting_prepacked_artifact_key = "/my/inference/tarball.tgz"
166+
specs.hosting_script_key = "/my/inference/script.py"
167+
specs.training_prepacked_artifact_key = "/my/training/tarball.tgz"
168+
specs.training_script_key = "/my/training/script.py"
169+
170+
response = generate_file_infos_from_model_specs(specs, {}, "us-west-2", s3_client)
171+
172+
assert response == []

tests/unit/sagemaker/jumpstart/curated_hub/test_utils.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -175,15 +175,14 @@ def test_create_hub_bucket_if_it_does_not_exist_hub_arn():
175175
assert utils.generate_hub_arn_for_init_kwargs(hub_arn, None, mock_custom_session) == hub_arn
176176

177177

178-
def test_generate_default_hub_bucket_name():
179-
mock_sagemaker_session = Mock()
180-
mock_sagemaker_session.account_id.return_value = "123456789123"
181-
mock_sagemaker_session.boto_region_name = "us-east-1"
178+
def test_is_gated_bucket():
179+
assert utils.is_gated_bucket("jumpstart-private-cache-prod-us-west-2") is True
182180

183-
assert (
184-
utils.generate_default_hub_bucket_name(sagemaker_session=mock_sagemaker_session)
185-
== "sagemaker-hubs-us-east-1-123456789123"
186-
)
181+
assert utils.is_gated_bucket("jumpstart-private-cache-prod-us-east-1") is True
182+
183+
assert utils.is_gated_bucket("jumpstart-cache-prod-us-west-2") is False
184+
185+
assert utils.is_gated_bucket("") is False
187186

188187

189188
def test_create_hub_bucket_if_it_does_not_exist():

tests/unit/sagemaker/jumpstart/test_accessors.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -98,9 +98,6 @@ def test_jumpstart_models_cache_get_model_specs(mock_cache):
9898
)
9999
)
100100

101-
# necessary because accessors is a static module
102-
reload(accessors)
103-
104101

105102
@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._cache")
106103
def test_jumpstart_proprietary_models_cache_get(mock_cache):

tests/unit/sagemaker/jumpstart/utils.py

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -254,16 +254,6 @@ def patched_retrieval_function(
254254
)
255255
)
256256

257-
if datatype == HubContentType.MODEL:
258-
_, _, _, model_name, model_version = id_info.split("/")
259-
return JumpStartCachedContentValue(
260-
formatted_content=get_spec_from_base_spec(model_id=model_name, version=model_version)
261-
)
262-
263-
# TODO: Implement
264-
if datatype == HubType.HUB:
265-
return None
266-
267257
raise ValueError(f"Bad value for datatype: {datatype}")
268258

269259

0 commit comments

Comments
 (0)