15
15
import re
16
16
from typing import Optional
17
17
from sagemaker .jumpstart .curated_hub .types import S3ObjectLocation
18
+ from sagemaker .jumpstart .constants import JUMPSTART_LOGGER
18
19
from sagemaker .s3_utils import parse_s3_url
19
20
from sagemaker .session import Session
20
21
from sagemaker .utils import aws_partition
25
26
HubArnExtractedInfo
26
27
)
27
28
from sagemaker .jumpstart .curated_hub .types import (
28
- CuratedHubTag ,
29
- CuratedHubTagName ,
29
+ CuratedHubUnsupportedFlag ,
30
30
HubContentSummary ,
31
31
JumpStartModelInfo
32
32
)
40
40
TASK_TAG_PREFIX ,
41
41
FRAMEWORK_TAG_PREFIX ,
42
42
)
43
- from uuid import uuid4
43
+ from sagemaker .utils import (
44
+ format_tags ,
45
+ TagsDict
46
+ )
44
47
45
48
46
49
def get_info_from_hub_resource_arn (
@@ -183,13 +186,14 @@ def create_hub_bucket_if_it_does_not_exist(
183
186
184
187
return bucket_name
185
188
186
- def tag_hub_content (hub_content_arn : str , tags : List [CuratedHubTag ], session : Session ) -> None :
189
+ def tag_hub_content (hub_content_arn : str , tags : List [TagsDict ], session : Session ) -> None :
187
190
session .add_tags (
188
191
ResourceArn = hub_content_arn ,
189
- Tags = [ tag_to_add_tags_api_call ( tag ) for tag in tags ]
192
+ Tags = str ( tags )
190
193
)
194
+ JUMPSTART_LOGGER .info (f"Added tags to HubContentArn %s: %s" , hub_content_arn , TagsDict )
191
195
192
- def find_jumpstart_tags_for_hub_content (hub_name : str , hub_content_name : str , region : str , session : Session ) -> List [CuratedHubTag ]:
196
+ def find_unsupported_flags_for_hub_content_versions (hub_name : str , hub_content_name : str , region : str , session : Session ) -> List [TagsDict ]:
193
197
"""Finds the JumpStart public hub model for a HubContent and calculates relevant tags.
194
198
195
199
Since tags are the same for all versions of a HubContent, these tags will map from the key to a list of versions impacted.
@@ -203,33 +207,33 @@ def find_jumpstart_tags_for_hub_content(hub_name: str, hub_content_name: str, re
203
207
)
204
208
hub_content_versions : List [HubContentSummary ] = summary_list_from_list_api_response (list_versions_response )
205
209
206
- tag_name_to_versions_map : Dict [CuratedHubTagName , List [str ]] = {}
210
+ unsupported_hub_content_versions_map : Dict [str , List [str ]] = {}
207
211
for hub_content_version_summary in hub_content_versions :
208
212
jumpstart_model = get_jumpstart_model_and_version (hub_content_version_summary )
209
213
if jumpstart_model is None :
210
214
continue
211
- tag_names_to_add : List [CuratedHubTagName ] = find_jumpstart_tags_for_model_version (
215
+ tag_names_to_add : List [CuratedHubUnsupportedFlag ] = find_unsupported_flags_for_model_version (
212
216
model_id = jumpstart_model .model_id ,
213
217
version = jumpstart_model .version ,
214
218
region = region ,
215
219
session = session
216
220
)
217
221
218
222
for tag_name in tag_names_to_add :
219
- if tag_name not in tag_name_to_versions_map :
220
- tag_name_to_versions_map [tag_name ] = []
221
- tag_name_to_versions_map [tag_name ].append (hub_content_version_summary .hub_content_version )
223
+ if tag_name not in unsupported_hub_content_versions_map :
224
+ unsupported_hub_content_versions_map [tag_name . value ] = []
225
+ unsupported_hub_content_versions_map [tag_name . value ].append (hub_content_version_summary .hub_content_version )
222
226
223
- return [ CuratedHubTag ( tag_name , str ( versions )) for ( tag_name , versions ) in tag_name_to_versions_map . items ()]
227
+ return format_tags ( unsupported_hub_content_versions_map )
224
228
225
229
226
- def find_jumpstart_tags_for_model_version (model_id : str , version : str , region : str , session : Session ) -> List [CuratedHubTagName ]:
230
+ def find_unsupported_flags_for_model_version (model_id : str , version : str , region : str , session : Session ) -> List [CuratedHubUnsupportedFlag ]:
227
231
"""Finds relevant CuratedHubTags for a version of a JumpStart public hub model.
228
232
229
233
For example, if the public hub model is deprecated, this utility will return a `deprecated` tag.
230
234
Since tags are the same for all versions of a HubContent, these tags will map from the key to a list of versions impacted.
231
235
"""
232
- tags_to_add : List [CuratedHubTagName ] = []
236
+ flags_to_add : List [CuratedHubUnsupportedFlag ] = []
233
237
jumpstart_model_specs = utils .verify_model_region_and_return_specs (
234
238
model_id = model_id ,
235
239
version = version ,
@@ -241,13 +245,13 @@ def find_jumpstart_tags_for_model_version(model_id: str, version: str, region: s
241
245
)
242
246
243
247
if (jumpstart_model_specs .deprecated ):
244
- tags_to_add .append (CuratedHubTagName .DEPRECATED_VERSIONS )
248
+ flags_to_add .append (CuratedHubUnsupportedFlag .DEPRECATED_VERSIONS )
245
249
if (jumpstart_model_specs .inference_vulnerable ):
246
- tags_to_add .append (CuratedHubTagName .INFERENCE_VULNERABLE_VERSIONS )
250
+ flags_to_add .append (CuratedHubUnsupportedFlag .INFERENCE_VULNERABLE_VERSIONS )
247
251
if (jumpstart_model_specs .training_vulnerable ):
248
- tags_to_add .append (CuratedHubTagName .TRAINING_VULNERABLE_VERSIONS )
252
+ flags_to_add .append (CuratedHubUnsupportedFlag .TRAINING_VULNERABLE_VERSIONS )
249
253
250
- return tags_to_add
254
+ return flags_to_add
251
255
252
256
253
257
@@ -272,7 +276,7 @@ def get_jumpstart_model_and_version(hub_content_summary: HubContentSummary) -> O
272
276
if jumpstart_model_id_tag is None or jumpstart_model_version_tag is None :
273
277
return None
274
278
jumpstart_model_id = jumpstart_model_id_tag [len (JUMPSTART_HUB_MODEL_ID_TAG_PREFIX ):] # Need to remove the tag_prefix and ":"
275
- jumpstart_model_version = jumpstart_model_version_tag [len (JUMPSTART_HUB_MODEL_ID_TAG_PREFIX ):]
279
+ jumpstart_model_version = jumpstart_model_version_tag [len (JUMPSTART_HUB_MODEL_VERSION_TAG_PREFIX ):]
276
280
return JumpStartModelInfo (model_id = jumpstart_model_id , version = jumpstart_model_version )
277
281
278
282
def summary_from_list_api_response (hub_content_summary : Dict [str , Any ]) -> HubContentSummary :
@@ -292,11 +296,3 @@ def summary_from_list_api_response(hub_content_summary: Dict[str, Any]) -> HubCo
292
296
def summary_list_from_list_api_response (list_hub_contents_response : Dict [str , Any ]) -> List [HubContentSummary ]:
293
297
return list (map (summary_from_list_api_response , list_hub_contents_response ["HubContentSummaries" ]))
294
298
295
- def tag_to_add_tags_api_call (tag : CuratedHubTag ) -> Dict [str , str ]:
296
- return {
297
- 'Key' : tag .key ,
298
- 'Value' : tag .value
299
- }
300
-
301
- def generate_unique_hub_content_model_name (model_id : str ) -> str :
302
- return f"{ model_id } -{ uuid4 ()} "
0 commit comments