13
13
"""This module contains utilities related to SageMaker JumpStart CuratedHub."""
14
14
from __future__ import absolute_import
15
15
import re
16
- from typing import Optional
16
+ from typing import Optional , Dict , List , Any
17
17
from sagemaker .jumpstart .curated_hub .types import S3ObjectLocation
18
18
from sagemaker .jumpstart .constants import JUMPSTART_LOGGER
19
19
from sagemaker .s3_utils import parse_s3_url
20
20
from sagemaker .session import Session
21
21
from sagemaker .utils import aws_partition
22
- from typing import Optional , Dict , List , Any
23
22
from sagemaker .jumpstart .types import HubContentType , HubArnExtractedInfo
24
23
from sagemaker .jumpstart .curated_hub .types import (
25
24
CuratedHubUnsupportedFlag ,
@@ -202,6 +201,18 @@ def find_unsupported_flags_for_hub_content_versions(
202
201
)
203
202
204
203
unsupported_hub_content_versions_map : Dict [str , List [str ]] = {}
204
+ version_to_tag_map = _get_tags_for_all_versions (hub_content_versions , region , session )
205
+ unsupported_hub_content_versions_map = _convert_to_tag_to_versions_map (version_to_tag_map )
206
+
207
+ return format_tags (unsupported_hub_content_versions_map )
208
+
209
+
210
+ def _get_tags_for_all_versions (
211
+ hub_content_versions : List [HubContentSummary ],
212
+ region : str ,
213
+ session : Session ,
214
+ ) -> Dict [str , List [CuratedHubUnsupportedFlag ]]:
215
+ version_to_tags_map : Dict [str , List [CuratedHubUnsupportedFlag ]] = {}
205
216
for hub_content_version_summary in hub_content_versions :
206
217
jumpstart_model = get_jumpstart_model_and_version (hub_content_version_summary )
207
218
if jumpstart_model is None :
@@ -215,14 +226,22 @@ def find_unsupported_flags_for_hub_content_versions(
215
226
session = session ,
216
227
)
217
228
218
- for tag_name in tag_names_to_add :
219
- if tag_name not in unsupported_hub_content_versions_map :
220
- unsupported_hub_content_versions_map [tag_name .value ] = []
221
- unsupported_hub_content_versions_map [tag_name .value ].append (
222
- hub_content_version_summary .hub_content_version
223
- )
229
+ version_to_tags_map [hub_content_version_summary .hub_content_version ] = tag_names_to_add
230
+ return version_to_tags_map
224
231
225
- return format_tags (unsupported_hub_content_versions_map )
232
+
233
+ def _convert_to_tag_to_versions_map (
234
+ version_to_tags_map : Dict [str , List [CuratedHubUnsupportedFlag ]]
235
+ ) -> Dict [CuratedHubUnsupportedFlag , List [str ]]:
236
+ unsupported_hub_content_versions_map : Dict [CuratedHubUnsupportedFlag , List [str ]] = {}
237
+ for version , tags in version_to_tags_map .items ():
238
+ for tag in tags :
239
+ if tag not in unsupported_hub_content_versions_map :
240
+ unsupported_hub_content_versions_map [tag ] = []
241
+ # Versions for a HubContent are unique
242
+ unsupported_hub_content_versions_map [tag ].append (version )
243
+
244
+ return unsupported_hub_content_versions_map
226
245
227
246
228
247
def find_unsupported_flags_for_model_version (
@@ -258,6 +277,7 @@ def find_unsupported_flags_for_model_version(
258
277
def get_jumpstart_model_and_version (
259
278
hub_content_summary : HubContentSummary ,
260
279
) -> Optional [JumpStartModelInfo ]:
280
+ """Retrieves the JumpStart model id and version from the JumpStart tag."""
261
281
jumpstart_model_id_tag = next (
262
282
(
263
283
tag
@@ -287,6 +307,7 @@ def get_jumpstart_model_and_version(
287
307
288
308
289
309
def summary_from_list_api_response (hub_content_summary : Dict [str , Any ]) -> HubContentSummary :
310
+ """Creates a single HubContentSummary from a HubContentSummary from the HubService List APIs."""
290
311
return HubContentSummary (
291
312
hub_content_arn = hub_content_summary .get ("HubContentArn" ),
292
313
hub_content_name = hub_content_summary .get ("HubContentName" ),
@@ -304,6 +325,7 @@ def summary_from_list_api_response(hub_content_summary: Dict[str, Any]) -> HubCo
304
325
def summary_list_from_list_api_response (
305
326
list_hub_contents_response : Dict [str , Any ]
306
327
) -> List [HubContentSummary ]:
328
+ """Creates a HubContentSummary list from either the ListHubContent or ListHubContentVersions API response."""
307
329
return list (
308
330
map (
309
331
summary_from_list_api_response ,
0 commit comments