Skip to content

Commit 9255849

Browse files
committed
fix: Addressing naming comments
1 parent 554dd20 commit 9255849

File tree

6 files changed

+74
-79
lines changed

6 files changed

+74
-79
lines changed

src/sagemaker/jumpstart/curated_hub/constants.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
"""This module stores constants related to SageMaker JumpStart CuratedHub."""
1414
from __future__ import absolute_import
1515

16-
JUMPSTART_HUB_MODEL_ID_TAG_PREFIX = "@jumpstart-model-id:"
17-
JUMPSTART_HUB_MODEL_VERSION_TAG_PREFIX = "@jumpstart-model-version:"
18-
FRAMEWORK_TAG_PREFIX = "@framework:"
19-
TASK_TAG_PREFIX = "@mltask:"
16+
JUMPSTART_HUB_MODEL_ID_TAG_PREFIX = "@jumpstart-model-id"
17+
JUMPSTART_HUB_MODEL_VERSION_TAG_PREFIX = "@jumpstart-model-version"
18+
FRAMEWORK_TAG_PREFIX = "@framework"
19+
TASK_TAG_PREFIX = "@mltask"

src/sagemaker/jumpstart/curated_hub/curated_hub.py

Lines changed: 36 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -50,22 +50,22 @@
5050
create_s3_object_reference_from_uri,
5151
tag_hub_content,
5252
get_jumpstart_model_and_version,
53-
find_jumpstart_tags_for_hub_content,
53+
find_unsupported_flags_for_hub_content_versions,
5454
summary_list_from_list_api_response,
5555
)
5656
from sagemaker.jumpstart.curated_hub.types import (
5757
HubContentDocument_v2,
5858
JumpStartModelInfo,
5959
S3ObjectLocation,
60-
CuratedHubTag,
6160
HubContentSummary,
6261
)
63-
64-
LIST_HUB_CACHE = None
62+
from sagemaker.utils import TagsDict
6563

6664
class CuratedHub:
6765
"""Class for creating and managing a curated JumpStart hub"""
6866

67+
_list_hubs_cache: Dict[str, Any] = None
68+
6969
def __init__(
7070
self,
7171
hub_name: str,
@@ -153,25 +153,20 @@ def describe(self) -> DescribeHubResponse:
153153

154154

155155
def list_models(self, clear_cache: bool = True, **kwargs) -> List[Dict[str, Any]]:
156-
"""Lists the models in this Curated Hub
156+
"""Lists the models in this Curated Hub.
157157
158-
**kwargs: Passed to invocation of ``Session:list_hub_contents``.
159-
"""
160-
if clear_cache:
161-
LIST_HUB_CACHE = None
162-
if LIST_HUB_CACHE:
163-
return LIST_HUB_CACHE
164-
return self._list_models(**kwargs)
165-
166-
def _list_models(self, **kwargs) -> Dict[str, Any]:
167-
"""Lists the models in this Curated Hub
158+
The listing is cached in local memory; the cache is reused only when ``clear_cache`` is False.
168159
169160
**kwargs: Passed to invocation of ``Session:list_hub_contents``.
170161
"""
171-
hub_content_summaries = self._sagemaker_session.list_hub_contents(
162+
if clear_cache:
163+
self._list_hubs_cache = None
164+
if self._list_hubs_cache is None:
165+
hub_content_summaries = self._sagemaker_session.list_hub_contents(
172166
hub_name=self.hub_name, hub_content_type=HubContentType.MODEL, **kwargs
173-
)
174-
return hub_content_summaries
167+
)
168+
self._list_hubs_cache = hub_content_summaries
169+
return self._list_hubs_cache
175170

176171
def describe_model(
177172
self, model_name: str, model_version: str = "*"
@@ -242,9 +237,29 @@ def _determine_models_to_sync(
242237
models_to_sync = []
243238
for model in model_list:
244239
matched_model = models_in_hub.get(model.model_id)
245-
if not matched_model or Version(matched_model.hub_content_version) < Version(model.version):
240+
241+
# Model does not exist in Hub, sync
242+
if not matched_model:
246243
models_to_sync.append(model)
247244

245+
if matched_model:
246+
model_version = Version(model.version)
247+
hub_model_version = Version(matched_model.hub_content_version)
248+
249+
# 1. Model version exists in Hub, pass
250+
if hub_model_version == model_version:
251+
pass
252+
253+
# 2. Invalid model version exists in Hub, pass
254+
# This will only happen if something goes wrong in our metadata
255+
if hub_model_version > model_version:
256+
pass
257+
258+
# 3. Old model version exists in Hub, update
259+
if hub_model_version < model_version:
260+
# Check minSDKVersion against current SDK version, emit log
261+
models_to_sync.append(model)
262+
248263
return models_to_sync
249264

250265
def sync(self, model_list: List[Dict[str, str]]):
@@ -416,9 +431,8 @@ def scan_and_tag_models(self, model_list: List[Dict[str, str]] = None) -> None:
416431

417432
models_to_scan = model_list if model_list else self.list_models()
418433
js_models_in_hub = [model for model in models_to_scan if get_jumpstart_model_and_version(model) is not None]
419-
tags_added: Dict[str, List[CuratedHubTag]] = {}
420434
for model in js_models_in_hub:
421-
tags_to_add: List[CuratedHubTag] = find_jumpstart_tags_for_hub_content(
435+
tags_to_add: List[TagsDict] = find_unsupported_flags_for_hub_content_versions(
422436
hub_name=self.hub_name,
423437
hub_content_name=model.hub_content_name,
424438
region=self.region,
@@ -429,9 +443,4 @@ def scan_and_tag_models(self, model_list: List[Dict[str, str]] = None) -> None:
429443
tags=tags_to_add,
430444
session=self._sagemaker_session
431445
)
432-
tags_added.update({model.hub_content_arn: tags_to_add})
433-
434-
output_string = "No tags were added!"
435-
if len(tags_added) > 0:
436-
output_string = f"Added the following tags: {tags_added}"
437-
JUMPSTART_LOGGER.info(output_string)
446+
JUMPSTART_LOGGER.info("Tagging complete!")

src/sagemaker/jumpstart/curated_hub/types.py

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919

2020
from sagemaker.jumpstart.types import JumpStartDataHolderType, JumpStartModelSpecs, HubContentType
2121

22-
class CuratedHubTagName(str, Enum):
22+
class CuratedHubUnsupportedFlag(str, Enum):
2323
"""Enum class for Curated Hub tag names."""
2424
DEPRECATED_VERSIONS = "deprecated_versions"
2525
TRAINING_VULNERABLE_VERSIONS = "training_vulnerable_versions"
@@ -38,14 +38,6 @@ class HubContentSummary:
3838
hub_content_display_name: str = None
3939
hub_content_description: str = None
4040
hub_content_search_keywords: List[str] = None
41-
42-
43-
@dataclass
44-
class CuratedHubTag:
45-
"""Dataclass to store Curated Hub-specific tags."""
46-
key: CuratedHubTagName
47-
value: str
48-
4941

5042
@dataclass
5143
class S3ObjectLocation:

src/sagemaker/jumpstart/curated_hub/utils.py

Lines changed: 23 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import re
1616
from typing import Optional
1717
from sagemaker.jumpstart.curated_hub.types import S3ObjectLocation
18+
from sagemaker.jumpstart.constants import JUMPSTART_LOGGER
1819
from sagemaker.s3_utils import parse_s3_url
1920
from sagemaker.session import Session
2021
from sagemaker.utils import aws_partition
@@ -25,8 +26,7 @@
2526
HubArnExtractedInfo
2627
)
2728
from sagemaker.jumpstart.curated_hub.types import (
28-
CuratedHubTag,
29-
CuratedHubTagName,
29+
CuratedHubUnsupportedFlag,
3030
HubContentSummary,
3131
JumpStartModelInfo
3232
)
@@ -40,7 +40,10 @@
4040
TASK_TAG_PREFIX,
4141
FRAMEWORK_TAG_PREFIX,
4242
)
43-
from uuid import uuid4
43+
from sagemaker.utils import (
44+
format_tags,
45+
TagsDict
46+
)
4447

4548

4649
def get_info_from_hub_resource_arn(
@@ -183,13 +186,14 @@ def create_hub_bucket_if_it_does_not_exist(
183186

184187
return bucket_name
185188

186-
def tag_hub_content(hub_content_arn: str, tags: List[CuratedHubTag], session: Session) -> None:
189+
def tag_hub_content(hub_content_arn: str, tags: List[TagsDict], session: Session) -> None:
187190
session.add_tags(
188191
ResourceArn=hub_content_arn,
189-
Tags=[tag_to_add_tags_api_call(tag) for tag in tags]
192+
Tags=str(tags)
190193
)
194+
JUMPSTART_LOGGER.info(f"Added tags to HubContentArn %s: %s", hub_content_arn, TagsDict)
191195

192-
def find_jumpstart_tags_for_hub_content(hub_name: str, hub_content_name: str, region: str, session: Session) -> List[CuratedHubTag]:
196+
def find_unsupported_flags_for_hub_content_versions(hub_name: str, hub_content_name: str, region: str, session: Session) -> List[TagsDict]:
193197
"""Finds the JumpStart public hub model for a HubContent and calculates relevant tags.
194198
195199
Since tags are shared by all versions of a HubContent, each tag maps from its key to the list of versions impacted.
@@ -203,33 +207,33 @@ def find_jumpstart_tags_for_hub_content(hub_name: str, hub_content_name: str, re
203207
)
204208
hub_content_versions: List[HubContentSummary] = summary_list_from_list_api_response(list_versions_response)
205209

206-
tag_name_to_versions_map: Dict[CuratedHubTagName, List[str]] = {}
210+
unsupported_hub_content_versions_map: Dict[str, List[str]] = {}
207211
for hub_content_version_summary in hub_content_versions:
208212
jumpstart_model = get_jumpstart_model_and_version(hub_content_version_summary)
209213
if jumpstart_model is None:
210214
continue
211-
tag_names_to_add: List[CuratedHubTagName] = find_jumpstart_tags_for_model_version(
215+
tag_names_to_add: List[CuratedHubUnsupportedFlag] = find_unsupported_flags_for_model_version(
212216
model_id=jumpstart_model.model_id,
213217
version=jumpstart_model.version,
214218
region=region,
215219
session=session
216220
)
217221

218222
for tag_name in tag_names_to_add:
219-
if tag_name not in tag_name_to_versions_map:
220-
tag_name_to_versions_map[tag_name] = []
221-
tag_name_to_versions_map[tag_name].append(hub_content_version_summary.hub_content_version)
223+
if tag_name not in unsupported_hub_content_versions_map:
224+
unsupported_hub_content_versions_map[tag_name.value] = []
225+
unsupported_hub_content_versions_map[tag_name.value].append(hub_content_version_summary.hub_content_version)
222226

223-
return [CuratedHubTag(tag_name, str(versions)) for (tag_name, versions) in tag_name_to_versions_map.items()]
227+
return format_tags(unsupported_hub_content_versions_map)
224228

225229

226-
def find_jumpstart_tags_for_model_version(model_id: str, version: str, region: str, session: Session) -> List[CuratedHubTagName]:
230+
def find_unsupported_flags_for_model_version(model_id: str, version: str, region: str, session: Session) -> List[CuratedHubUnsupportedFlag]:
227231
"""Finds relevant CuratedHubTags for a version of a JumpStart public hub model.
228232
229233
For example, if the public hub model version is deprecated, this utility will return the `deprecated_versions` flag.
230234
Since tags are the same for all versions of a HubContent, the caller aggregates these flags into a map from each flag to the list of versions impacted.
231235
"""
232-
tags_to_add: List[CuratedHubTagName] = []
236+
flags_to_add: List[CuratedHubUnsupportedFlag] = []
233237
jumpstart_model_specs = utils.verify_model_region_and_return_specs(
234238
model_id=model_id,
235239
version=version,
@@ -241,13 +245,13 @@ def find_jumpstart_tags_for_model_version(model_id: str, version: str, region: s
241245
)
242246

243247
if (jumpstart_model_specs.deprecated):
244-
tags_to_add.append(CuratedHubTagName.DEPRECATED_VERSIONS)
248+
flags_to_add.append(CuratedHubUnsupportedFlag.DEPRECATED_VERSIONS)
245249
if (jumpstart_model_specs.inference_vulnerable):
246-
tags_to_add.append(CuratedHubTagName.INFERENCE_VULNERABLE_VERSIONS)
250+
flags_to_add.append(CuratedHubUnsupportedFlag.INFERENCE_VULNERABLE_VERSIONS)
247251
if (jumpstart_model_specs.training_vulnerable):
248-
tags_to_add.append(CuratedHubTagName.TRAINING_VULNERABLE_VERSIONS)
252+
flags_to_add.append(CuratedHubUnsupportedFlag.TRAINING_VULNERABLE_VERSIONS)
249253

250-
return tags_to_add
254+
return flags_to_add
251255

252256

253257

@@ -272,7 +276,7 @@ def get_jumpstart_model_and_version(hub_content_summary: HubContentSummary) -> O
272276
if jumpstart_model_id_tag is None or jumpstart_model_version_tag is None:
273277
return None
274278
jumpstart_model_id = jumpstart_model_id_tag[len(JUMPSTART_HUB_MODEL_ID_TAG_PREFIX):] # Need to remove the tag_prefix and ":"
275-
jumpstart_model_version = jumpstart_model_version_tag[len(JUMPSTART_HUB_MODEL_ID_TAG_PREFIX):]
279+
jumpstart_model_version = jumpstart_model_version_tag[len(JUMPSTART_HUB_MODEL_VERSION_TAG_PREFIX):]
276280
return JumpStartModelInfo(model_id=jumpstart_model_id, version=jumpstart_model_version)
277281

278282
def summary_from_list_api_response(hub_content_summary: Dict[str, Any]) -> HubContentSummary:
@@ -292,11 +296,3 @@ def summary_from_list_api_response(hub_content_summary: Dict[str, Any]) -> HubCo
292296
def summary_list_from_list_api_response(list_hub_contents_response: Dict[str, Any]) -> List[HubContentSummary]:
293297
return list(map(summary_from_list_api_response, list_hub_contents_response["HubContentSummaries"]))
294298

295-
def tag_to_add_tags_api_call(tag: CuratedHubTag) -> Dict[str, str]:
296-
return {
297-
'Key': tag.key,
298-
'Value': tag.value
299-
}
300-
301-
def generate_unique_hub_content_model_name(model_id: str) -> str:
302-
return f"{model_id}-{uuid4()}"

tests/unit/sagemaker/jumpstart/curated_hub/test_curated_hub.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -177,7 +177,6 @@ def test_sync_kicks_off_parallel_syncs(
177177

178178
hub.sync([model_one, model_two])
179179

180-
# mock_get_model_specs.assert_called_once()
181180
mock_sync_public_models.assert_has_calls(
182181
[
183182
mock.call(JumpStartModelInfo("mock-model-one-huggingface", "*"), 0),

tests/unit/sagemaker/jumpstart/curated_hub/test_utils.py

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,7 @@
1919
from sagemaker.jumpstart.curated_hub import utils
2020
from unittest.mock import patch
2121
from sagemaker.jumpstart.curated_hub.types import (
22-
CuratedHubTag,
23-
CuratedHubTagName,
22+
CuratedHubUnsupportedFlag,
2423
HubContentSummary
2524
)
2625
from sagemaker.jumpstart.types import JumpStartDataHolderType, JumpStartModelSpecs, HubContentType
@@ -187,7 +186,7 @@ def test_find_tags_for_jumpstart_model_version(mock_spec_util):
187186
mock_specs.training_vulnerable = True
188187
mock_spec_util.return_value = mock_specs
189188

190-
tags = utils.find_jumpstart_tags_for_model_version(
189+
tags = utils.find_unsupported_flags_for_model_version(
191190
model_id="test",
192191
version="test",
193192
region="test",
@@ -204,7 +203,7 @@ def test_find_tags_for_jumpstart_model_version(mock_spec_util):
204203
sagemaker_session=mock_sagemaker_session,
205204
)
206205

207-
assert tags == [CuratedHubTagName.DEPRECATED_VERSIONS, CuratedHubTagName.INFERENCE_VULNERABLE_VERSIONS, CuratedHubTagName.TRAINING_VULNERABLE_VERSIONS]
206+
assert tags == [CuratedHubUnsupportedFlag.DEPRECATED_VERSIONS, CuratedHubUnsupportedFlag.INFERENCE_VULNERABLE_VERSIONS, CuratedHubUnsupportedFlag.TRAINING_VULNERABLE_VERSIONS]
208207

209208
@patch("sagemaker.jumpstart.utils.verify_model_region_and_return_specs")
210209
def test_find_tags_for_jumpstart_model_version_some_false(mock_spec_util):
@@ -215,7 +214,7 @@ def test_find_tags_for_jumpstart_model_version_some_false(mock_spec_util):
215214
mock_specs.training_vulnerable = False
216215
mock_spec_util.return_value = mock_specs
217216

218-
tags = utils.find_jumpstart_tags_for_model_version(
217+
tags = utils.find_unsupported_flags_for_model_version(
219218
model_id="test",
220219
version="test",
221220
region="test",
@@ -232,7 +231,7 @@ def test_find_tags_for_jumpstart_model_version_some_false(mock_spec_util):
232231
sagemaker_session=mock_sagemaker_session,
233232
)
234233

235-
assert tags == [CuratedHubTagName.DEPRECATED_VERSIONS]
234+
assert tags == [CuratedHubUnsupportedFlag.DEPRECATED_VERSIONS]
236235

237236
@patch("sagemaker.jumpstart.utils.verify_model_region_and_return_specs")
238237
def test_find_tags_for_jumpstart_model_version_all_false(mock_spec_util):
@@ -243,7 +242,7 @@ def test_find_tags_for_jumpstart_model_version_all_false(mock_spec_util):
243242
mock_specs.training_vulnerable = False
244243
mock_spec_util.return_value = mock_specs
245244

246-
tags = utils.find_jumpstart_tags_for_model_version(
245+
tags = utils.find_unsupported_flags_for_model_version(
247246
model_id="test",
248247
version="test",
249248
region="test",
@@ -294,7 +293,7 @@ def test_find_all_tags_for_jumpstart_model_filters_non_jumpstart_models(mock_spe
294293
mock_specs.training_vulnerable = True
295294
mock_spec_util.return_value = mock_specs
296295

297-
tags = utils.find_jumpstart_tags_for_hub_content(
296+
tags = utils.find_unsupported_flags_for_hub_content_versions(
298297
hub_name="test",
299298
hub_content_name="test",
300299
region="test",
@@ -308,9 +307,9 @@ def test_find_all_tags_for_jumpstart_model_filters_non_jumpstart_models(mock_spe
308307
)
309308

310309
assert tags == [
311-
CuratedHubTag(key=CuratedHubTagName.DEPRECATED_VERSIONS, value=str(["1.0.0", "2.0.0"])),
312-
CuratedHubTag(key=CuratedHubTagName.INFERENCE_VULNERABLE_VERSIONS, value=str(["1.0.0", "2.0.0"])),
313-
CuratedHubTag(key=CuratedHubTagName.TRAINING_VULNERABLE_VERSIONS, value=str(["1.0.0", "2.0.0"]))
310+
{"Key":CuratedHubUnsupportedFlag.DEPRECATED_VERSIONS.value, "Value":str(["1.0.0", "2.0.0"])},
311+
{"Key":CuratedHubUnsupportedFlag.INFERENCE_VULNERABLE_VERSIONS.value, "Value":str(["1.0.0", "2.0.0"])},
312+
{"Key":CuratedHubUnsupportedFlag.TRAINING_VULNERABLE_VERSIONS.value, "Value":str(["1.0.0", "2.0.0"])}
314313
]
315314

316315
@patch("sagemaker.jumpstart.utils.verify_model_region_and_return_specs")

0 commit comments

Comments
 (0)