-
Notifications
You must be signed in to change notification settings - Fork 1.2k
feat: Adding scan and tagging utility #4499
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 9 commits
73ce401
fbd233f
6d43cbe
2ec0ca5
c1519b8
c2d4a17
48d6325
0947840
5f8119b
395e99c
c957e71
453bed4
9e703a9
7cda304
9dbe3d6
554dd20
9255849
6d5f599
77d7b50
0e5d5b4
4b82389
42efbaa
a73b5b2
4252d86
0deb20e
d21fd41
6b107a2
29a0740
054a8cf
bddf5e3
bd02739
b33a63b
e1da9fc
4c829c9
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -19,6 +19,16 @@ | |
|
||
from sagemaker.jumpstart.types import JumpStartDataHolderType, JumpStartModelSpecs | ||
|
||
class CuratedHubTagName(str, Enum):
    """Names of the tags that can be attached to Curated Hub content.

    The members mix in ``str``, so each one compares equal to its string
    value and can be passed wherever a plain tag key is expected.
    """

    # Applied for hub content versions whose JumpStart specs are flagged as deprecated.
    DEPRECATED_VERSIONS_TAG = "deprecated_versions"
    # Applied for hub content versions whose specs are flagged as training-vulnerable.
    TRAINING_VULNERABLE_VERSIONS_TAG = "training_vulnerable_versions"
    # Applied for hub content versions whose specs are flagged as inference-vulnerable.
    INFERENCE_VULNERABLE_VERSIONS_TAG = "inference_vulnerable_versions"
chrstfu marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
@dataclass
class Tag:
    """A single key/value tag to attach to Curated Hub content."""

    # Which curated-hub tag this is.
    key: CuratedHubTagName
    # Tag payload; the scan utility stores the stringified list of affected
    # hub content versions here.
    value: str
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Okay, I'll disagree and commit :). But if you're going to keep this, why not make the |
||
|
||
@dataclass | ||
class S3ObjectLocation: | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -18,11 +18,26 @@ | |
from sagemaker.s3_utils import parse_s3_url | ||
from sagemaker.session import Session | ||
from sagemaker.utils import aws_partition | ||
from typing import Optional, Dict, List, Any, Set | ||
from botocore.exceptions import ClientError | ||
from sagemaker.jumpstart.types import ( | ||
HubContentType, | ||
HubArnExtractedInfo, | ||
) | ||
from sagemaker.jumpstart.curated_hub.types import ( | ||
Tag, | ||
CuratedHubTagName | ||
) | ||
from sagemaker.jumpstart import constants | ||
from sagemaker.jumpstart import utils | ||
from sagemaker.session import Session | ||
from sagemaker.jumpstart.enums import JumpStartScriptScope | ||
from sagemaker.jumpstart.curated_hub.constants import ( | ||
JUMPSTART_HUB_MODEL_ID_TAG_PREFIX, | ||
JUMPSTART_HUB_MODEL_VERSION_TAG_PREFIX, | ||
TASK_TAG_PREFIX, | ||
FRAMEWORK_TAG_PREFIX, | ||
) | ||
|
||
|
||
def get_info_from_hub_resource_arn( | ||
|
@@ -164,3 +179,97 @@ def create_hub_bucket_if_it_does_not_exist( | |
) | ||
|
||
return bucket_name | ||
|
||
def tag_hub_content(
    hub_content_arn: str, tags: List[Tag], session: Session
) -> List[Dict[str, List[Dict[str, str]]]]:
    """Attach every tag in ``tags`` to the hub content identified by its ARN.

    One ``session.add_tags`` call is issued per tag, in the order given.

    Args:
        hub_content_arn: ARN of the hub content resource to tag.
        tags: Tags to attach; each contributes one ``Key``/``Value`` pair.
        session: Session object whose ``add_tags`` method performs the call.

    Returns:
        The ``add_tags`` responses, one per tag, in the same order as
        ``tags``. An empty ``tags`` list yields an empty list and no calls.
    """
    # The responses are collected in a list; the previous implementation
    # called ``.add`` on it, which raised AttributeError on the first tag.
    return [
        session.add_tags(
            ResourceArn=hub_content_arn,
            Tags=[
                {
                    "Key": tag.key,
                    "Value": tag.value,
                },
            ],
        )
        for tag in tags
    ]
|
||
def find_jumpstart_tags_for_model(
    hub_name: str, hub_content_name: str, region: str, session: Session
) -> List[Tag]:
    """Build the curated-hub tags for every version of one hub model.

    Each hub content version is mapped back to its JumpStart model id and
    version; versions for which either cannot be determined are skipped.
    The tag names returned for a version are grouped so that each tag name
    ends up with the list of hub content versions it applies to.

    Args:
        hub_name: Name of the hub that holds the content.
        hub_content_name: Name of the hub content (model) to scan.
        region: Region forwarded to the JumpStart spec lookup.
        session: Session used for the hub listing and the spec lookup.

    Returns:
        One ``Tag`` per tag name that applies to at least one version; the
        tag value is ``str()`` of the list of affected hub content versions.
    """
    summaries = session.list_hub_content_versions(
        hub_name=hub_name,
        hub_content_type='Model',
        hub_content_name=hub_content_name
    )["HubContentSummaries"]

    versions_by_tag_name: Dict[CuratedHubTagName, List[str]] = {}
    for summary in summaries:
        model_info = get_jumpstart_model_and_version(summary)
        if model_info["model_id"] is None or model_info["version"] is None:
            # Not traceable to a JumpStart model version; nothing to look up.
            continue

        applicable_tag_names: List[CuratedHubTagName] = find_jumpstart_tags_for_model_version(
            model_id=model_info["model_id"],
            version=model_info["version"],
            region=region,
            session=session,
        )
        for name in applicable_tag_names:
            versions_by_tag_name.setdefault(name, []).append(summary["HubContentVersion"])

    return [
        Tag(key=name, value=str(affected_versions))
        for name, affected_versions in versions_by_tag_name.items()
    ]
|
||
|
||
|
||
def find_jumpstart_tags_for_model_version(
    model_id: str, version: str, region: str, session: Session
) -> List[CuratedHubTagName]:
    """Work out which curated-hub tag names apply to one JumpStart model version.

    The model specs are looked up with the inference scope and with both
    ``tolerate_*`` switches enabled (presumably so that flagged versions are
    returned rather than rejected; confirm against the lookup helper).

    Args:
        model_id: JumpStart model id to look up.
        version: JumpStart model version to look up.
        region: Region to look the model up in.
        session: Session forwarded to the spec lookup.

    Returns:
        The tag names whose flag is set on the specs, in the order
        deprecated, inference-vulnerable, training-vulnerable.
    """
    model_specs = utils.verify_model_region_and_return_specs(
        model_id=model_id,
        version=version,
        region=region,
        scope=JumpStartScriptScope.INFERENCE,
        tolerate_vulnerable_model=True,
        tolerate_deprecated_model=True,
        sagemaker_session=session,
    )

    # (flag on the specs, tag name to emit when that flag is truthy)
    flag_and_tag_name = (
        (model_specs.deprecated, CuratedHubTagName.DEPRECATED_VERSIONS_TAG),
        (model_specs.inference_vulnerable, CuratedHubTagName.INFERENCE_VULNERABLE_VERSIONS_TAG),
        (model_specs.training_vulnerable, CuratedHubTagName.TRAINING_VULNERABLE_VERSIONS_TAG),
    )
    return [tag_name for is_set, tag_name in flag_and_tag_name if is_set]
|
||
|
||
|
||
def get_jumpstart_model_and_version(hub_content_summary: Dict[str, Any]) -> Dict[str, Any]:
    """Extract the JumpStart model id and version from a hub content summary.

    Both values are recovered from the summary's search keywords: the first
    keyword starting with ``JUMPSTART_HUB_MODEL_ID_TAG_PREFIX`` carries the
    model id, and the first starting with
    ``JUMPSTART_HUB_MODEL_VERSION_TAG_PREFIX`` carries the version. The
    prefix is removed so the results can be passed directly as ``model_id``
    and ``version`` to the JumpStart spec lookup.

    Args:
        hub_content_summary: Summary of one hub content version. Keywords are
            read from ``"search_keywords"`` or, when that key is absent, from
            ``"HubContentSearchKeywords"``.

    Returns:
        ``{"model_id": ..., "version": ...}``. A value is ``None`` when no
        keyword with the corresponding prefix exists (including when the
        summary carries no keywords at all).
    """
    keywords = hub_content_summary.get("search_keywords")
    if keywords is None:
        # The caller reads the PascalCase key "HubContentVersion" from this
        # same dict, so also accept the PascalCase keyword field.
        keywords = hub_content_summary.get("HubContentSearchKeywords") or []

    def _first_value_with_prefix(prefix: str) -> Optional[str]:
        """Return the first keyword starting with ``prefix``, prefix removed."""
        for keyword in keywords:
            if keyword.startswith(prefix):
                # NOTE(review): assumes the prefix constant already includes
                # any separator between prefix and value; confirm.
                return keyword[len(prefix):]
        return None

    return {
        "model_id": _first_value_with_prefix(JUMPSTART_HUB_MODEL_ID_TAG_PREFIX),
        "version": _first_value_with_prefix(JUMPSTART_HUB_MODEL_VERSION_TAG_PREFIX),
    }
chrstfu marked this conversation as resolved.
Show resolved
Hide resolved
|
Uh oh!
There was an error while loading. Please reload this page.