Skip to content

Commit 42efbaa

Browse files
committed
fix: linter
1 parent 4b82389 commit 42efbaa

File tree

2 files changed

+32
-9
lines changed

2 files changed

+32
-9
lines changed

src/sagemaker/jumpstart/curated_hub/curated_hub.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -219,6 +219,7 @@ def _populate_latest_model_version(self, model: Dict[str, str]) -> Dict[str, str
219219
return {"model_id": model["model_id"], "version": model_specs.version}
220220

221221
def _get_jumpstart_models_in_hub(self) -> List[HubContentSummary]:
222+
"""Retrieves all JumpStart models in a private Hub."""
222223
hub_models = summary_list_from_list_api_response(self.list_models())
223224
return [model for model in hub_models if get_jumpstart_model_and_version(model) is not None]
224225

src/sagemaker/jumpstart/curated_hub/utils.py

Lines changed: 31 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,12 @@
1313
"""This module contains utilities related to SageMaker JumpStart CuratedHub."""
1414
from __future__ import absolute_import
1515
import re
16-
from typing import Optional
16+
from typing import Optional, Dict, List, Any
1717
from sagemaker.jumpstart.curated_hub.types import S3ObjectLocation
1818
from sagemaker.jumpstart.constants import JUMPSTART_LOGGER
1919
from sagemaker.s3_utils import parse_s3_url
2020
from sagemaker.session import Session
2121
from sagemaker.utils import aws_partition
22-
from typing import Optional, Dict, List, Any
2322
from sagemaker.jumpstart.types import HubContentType, HubArnExtractedInfo
2423
from sagemaker.jumpstart.curated_hub.types import (
2524
CuratedHubUnsupportedFlag,
@@ -202,6 +201,18 @@ def find_unsupported_flags_for_hub_content_versions(
202201
)
203202

204203
unsupported_hub_content_versions_map: Dict[str, List[str]] = {}
204+
version_to_tag_map = _get_tags_for_all_versions(hub_content_versions, region, session)
205+
unsupported_hub_content_versions_map = _convert_to_tag_to_versions_map(version_to_tag_map)
206+
207+
return format_tags(unsupported_hub_content_versions_map)
208+
209+
210+
def _get_tags_for_all_versions(
211+
hub_content_versions: List[HubContentSummary],
212+
region: str,
213+
session: Session,
214+
) -> Dict[str, List[CuratedHubUnsupportedFlag]]:
215+
version_to_tags_map: Dict[str, List[CuratedHubUnsupportedFlag]] = {}
205216
for hub_content_version_summary in hub_content_versions:
206217
jumpstart_model = get_jumpstart_model_and_version(hub_content_version_summary)
207218
if jumpstart_model is None:
@@ -215,14 +226,22 @@ def find_unsupported_flags_for_hub_content_versions(
215226
session=session,
216227
)
217228

218-
for tag_name in tag_names_to_add:
219-
if tag_name not in unsupported_hub_content_versions_map:
220-
unsupported_hub_content_versions_map[tag_name.value] = []
221-
unsupported_hub_content_versions_map[tag_name.value].append(
222-
hub_content_version_summary.hub_content_version
223-
)
229+
version_to_tags_map[hub_content_version_summary.hub_content_version] = tag_names_to_add
230+
return version_to_tags_map
224231

225-
return format_tags(unsupported_hub_content_versions_map)
232+
233+
def _convert_to_tag_to_versions_map(
234+
version_to_tags_map: Dict[str, List[CuratedHubUnsupportedFlag]]
235+
) -> Dict[CuratedHubUnsupportedFlag, List[str]]:
236+
unsupported_hub_content_versions_map: Dict[CuratedHubUnsupportedFlag, List[str]] = {}
237+
for version, tags in version_to_tags_map.items():
238+
for tag in tags:
239+
if tag not in unsupported_hub_content_versions_map:
240+
unsupported_hub_content_versions_map[tag] = []
241+
# Versions for a HubContent are unique
242+
unsupported_hub_content_versions_map[tag].append(version)
243+
244+
return unsupported_hub_content_versions_map
226245

227246

228247
def find_unsupported_flags_for_model_version(
@@ -258,6 +277,7 @@ def find_unsupported_flags_for_model_version(
258277
def get_jumpstart_model_and_version(
259278
hub_content_summary: HubContentSummary,
260279
) -> Optional[JumpStartModelInfo]:
280+
"""Retrieves the JumpStart model id and version from the JumpStart tag."""
261281
jumpstart_model_id_tag = next(
262282
(
263283
tag
@@ -287,6 +307,7 @@ def get_jumpstart_model_and_version(
287307

288308

289309
def summary_from_list_api_response(hub_content_summary: Dict[str, Any]) -> HubContentSummary:
310+
"""Creates a single HubContentSummary from a HubContentSummary from the HubService List APIs."""
290311
return HubContentSummary(
291312
hub_content_arn=hub_content_summary.get("HubContentArn"),
292313
hub_content_name=hub_content_summary.get("HubContentName"),
@@ -304,6 +325,7 @@ def summary_from_list_api_response(hub_content_summary: Dict[str, Any]) -> HubCo
304325
def summary_list_from_list_api_response(
305326
list_hub_contents_response: Dict[str, Any]
306327
) -> List[HubContentSummary]:
328+
"""Creates a HubContentSummary list from either the ListHubContent or ListHubContentVersions API response."""
307329
return list(
308330
map(
309331
summary_from_list_api_response,

0 commit comments

Comments
 (0)