Skip to content

feat: Adding scan and tagging utility #4499

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
73ce401
fix: Adding new utils
chrstfu Mar 12, 2024
fbd233f
Merge branch 'curated_hub_tagris' into curated_hub_tagris_copy
chrstfu Mar 12, 2024
6d43cbe
feat: Adding Curated Hub scanning feature
chrstfu Mar 12, 2024
2ec0ca5
fix: Refactoring
chrstfu Mar 13, 2024
c1519b8
fix: Removing test values
chrstfu Mar 13, 2024
c2d4a17
fix: Refactoring
chrstfu Mar 13, 2024
48d6325
fix: removing lru cache temporarily
chrstfu Mar 13, 2024
0947840
fix: Adding unit tests
chrstfu Mar 13, 2024
5f8119b
fix: renaming
chrstfu Mar 13, 2024
395e99c
fix: Initial refactorings
chrstfu Mar 14, 2024
c957e71
fix: Adding more alterations
chrstfu Mar 15, 2024
453bed4
fix: Addressing unit tests
chrstfu Mar 15, 2024
9e703a9
fix: Adding more unittests
chrstfu Mar 15, 2024
7cda304
fix: Add tests
chrstfu Mar 15, 2024
9dbe3d6
fix: Adding list to scan input
chrstfu Mar 15, 2024
554dd20
fix: typo
chrstfu Mar 15, 2024
9255849
fix: Addressing naming comments
chrstfu Mar 15, 2024
6d5f599
fix: changing from string
chrstfu Mar 15, 2024
77d7b50
Merge branch 'master-jumpstart-curated-hub' into curated_hub_tagris_copy
chrstfu Mar 18, 2024
0e5d5b4
fix: formatter
chrstfu Mar 18, 2024
4b82389
fix: linter
chrstfu Mar 18, 2024
42efbaa
fix: linter
chrstfu Mar 18, 2024
a73b5b2
fix: linters
chrstfu Mar 18, 2024
4252d86
fix: linter
chrstfu Mar 18, 2024
0deb20e
fix: linting
chrstfu Mar 18, 2024
d21fd41
fix: more linting :/
chrstfu Mar 18, 2024
6b107a2
fix: linting
chrstfu Mar 18, 2024
29a0740
fix: more linting :/
chrstfu Mar 19, 2024
054a8cf
Merge branch 'master-jumpstart-curated-hub' into curated_hub_tagris_copy
chrstfu Mar 19, 2024
bddf5e3
fix: linting
chrstfu Mar 19, 2024
bd02739
fix: tests
chrstfu Mar 19, 2024
b33a63b
fix: Adding __init__.py
chrstfu Mar 19, 2024
e1da9fc
fix: linting
chrstfu Mar 19, 2024
4c829c9
fix: linting
chrstfu Mar 19, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 50 additions & 23 deletions src/sagemaker/jumpstart/curated_hub/curated_hub.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
"""This module provides the JumpStart Curated Hub class."""
from __future__ import absolute_import
from concurrent import futures
from functools import lru_cache
from datetime import datetime
import json
import traceback
Expand Down Expand Up @@ -47,11 +48,15 @@
create_hub_bucket_if_it_does_not_exist,
generate_default_hub_bucket_name,
create_s3_object_reference_from_uri,
tag_hub_content,
get_jumpstart_model_and_version,
find_jumpstart_tags_for_model
)
from sagemaker.jumpstart.curated_hub.types import (
HubContentDocument_v2,
JumpStartModelInfo,
S3ObjectLocation,
Tag
)


Expand Down Expand Up @@ -143,13 +148,13 @@ def describe(self) -> DescribeHubResponse:

return hub_description

def list_models(self, **kwargs) -> Dict[str, Any]:

def list_models(self, clear_cache: bool = True, **kwargs) -> Dict[str, Any]:
"""Lists the models in this Curated Hub

**kwargs: Passed to invocation of ``Session:list_hub_contents``.
"""
# TODO: Validate kwargs and fast-fail?

hub_content_summaries = self._sagemaker_session.list_hub_contents(
hub_name=self.hub_name, hub_content_type=HubContentType.MODEL, **kwargs
)
Expand Down Expand Up @@ -210,27 +215,10 @@ def _get_jumpstart_models_in_hub(self) -> List[Dict[str, Any]]:
hub_models = self.list_models()

js_models_in_hub = []
for hub_model in hub_models["HubContentSummaries"]:
# TODO: extract both in one pass
jumpstart_model_id = next(
(
tag
for tag in hub_model["search_keywords"]
if tag.startswith(JUMPSTART_HUB_MODEL_ID_TAG_PREFIX)
),
None,
)
jumpstart_model_version = next(
(
tag
for tag in hub_model["search_keywords"]
if tag.startswith(JUMPSTART_HUB_MODEL_VERSION_TAG_PREFIX)
),
None,
)

if jumpstart_model_id and jumpstart_model_version:
js_models_in_hub.append(hub_model)
for hub_model_summary in hub_models["HubContentSummaries"]:
jumpstart_model = get_jumpstart_model_and_version(hub_model_summary)
if jumpstart_model["model_id"] and jumpstart_model["version"]:
js_models_in_hub.append(hub_model_summary)

return js_models_in_hub

Expand Down Expand Up @@ -415,3 +403,42 @@ def _fetch_studio_specs(self, model_specs: JumpStartModelSpecs) -> Dict[str, Any
Bucket=utils.get_jumpstart_content_bucket(self.region), Key=key
)
return json.loads(response["Body"].read().decode("utf-8"))


def scan_and_tag_models(self) -> None:
"""Scans the Hub for JumpStart models and tags the HubContent.

If the scan detects a model is deprecated or vulnerable, it will tag the HubContent.
The tags that will be added are based off the specifications in the JumpStart public hub:
1. "deprecated_versions" -> If the public hub model is deprecated
2. "inference_vulnerable_versions" -> If the public hub model has inference vulnerabilities
3. "training_vulnerable_versions" -> If the public hub model has training vulnerabilities

The tag value will be a list of versions in the Curated Hub that fall under those keys.
For example, if model_a version_a is deprecated and inference is vulnerable, the
HubContent for `model_a` will have tags [{"deprecated_versions": [version_a]},
{"inference_vulnerable_versions": [version_a]}]
"""
JUMPSTART_LOGGER.info(
"Tagging models in hub: %s", self.hub_name
)
models_in_hub: List[Dict[str, Any]] = self._get_jumpstart_models_in_hub()
tags_added: List[Dict[str, List[Dict[str, str]]]] = []
for model in models_in_hub:
tags_to_add: List[Tag] = find_jumpstart_tags_for_model(
hub_name=self.hub_name,
hub_content_name=model["HubContentName"],
region=self.region,
session=self._sagemaker_session
)
tags_added_to_model = tag_hub_content(
hub_content_arn=model["HubContentArn"],
tags=tags_to_add,
session=self._sagemaker_session
)
tags_added.extend(tags_added_to_model)

output_string = "No tags were added!"
if len(tags_added) > 0:
output_string = f"Added the following tags: {tags_added}"
JUMPSTART_LOGGER.info(output_string)
10 changes: 10 additions & 0 deletions src/sagemaker/jumpstart/curated_hub/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,16 @@

from sagemaker.jumpstart.types import JumpStartDataHolderType, JumpStartModelSpecs

class CuratedHubTagName(str, Enum):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's come up with a better enum name than this. How about CuratedHubUnsupportedFlag?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IMO this name gives flexibility to tag with other values other than Unsupported values. I would personally vote to keep this as is

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm I still disagree. This enum should only be for UnsupportedFlags since it's trying to perform a specific job. If you want to make enums for all tags for CuratedHub models, you can create a new enum, then have a aggregator type

class CuratedHubUnsupportedFlag(str, Enum):
  ...

class CuratedHubBlahBlah(str, Enum):
  ...

CuratedHubTag = CuratedHubUnsupportedFlag | CuratedHubBlahBlah

I'd prefer if we didn't overthink this to be generalizable, and cross that bridge when we get there. This is a private beta, and we need to keep the logic rigid.

"""Enum class for Curated Hub """
DEPRECATED_VERSIONS_TAG = "deprecated_versions"
TRAINING_VULNERABLE_VERSIONS_TAG = "training_vulnerable_versions"
INFERENCE_VULNERABLE_VERSIONS_TAG = "inference_vulnerable_versions"

@dataclass
class Tag:
key: CuratedHubTagName
value: str
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay, I'll disagree and commit :). But if you're going to keep this, why not make the key/value be capitalized so it's compatible with sagemaker:AddTags?


@dataclass
class S3ObjectLocation:
Expand Down
109 changes: 109 additions & 0 deletions src/sagemaker/jumpstart/curated_hub/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,26 @@
from sagemaker.s3_utils import parse_s3_url
from sagemaker.session import Session
from sagemaker.utils import aws_partition
from typing import Optional, Dict, List, Any, Set
from botocore.exceptions import ClientError
from sagemaker.jumpstart.types import (
HubContentType,
HubArnExtractedInfo,
)
from sagemaker.jumpstart.curated_hub.types import (
Tag,
CuratedHubTagName
)
from sagemaker.jumpstart import constants
from sagemaker.jumpstart import utils
from sagemaker.session import Session
from sagemaker.jumpstart.enums import JumpStartScriptScope
from sagemaker.jumpstart.curated_hub.constants import (
JUMPSTART_HUB_MODEL_ID_TAG_PREFIX,
JUMPSTART_HUB_MODEL_VERSION_TAG_PREFIX,
TASK_TAG_PREFIX,
FRAMEWORK_TAG_PREFIX,
)


def get_info_from_hub_resource_arn(
Expand Down Expand Up @@ -164,3 +179,97 @@ def create_hub_bucket_if_it_does_not_exist(
)

return bucket_name

def tag_hub_content(hub_content_arn: str, tags: List[Tag], session: Session) -> List[Dict[str, List[Dict[str, str]]]]:
responses = []
for tag in tags:
responses.add(session.add_tags(
ResourceArn=hub_content_arn,
Tags=[
{
'Key': tag.key,
'Value': tag.value
},
]
))

return responses

def find_jumpstart_tags_for_model(hub_name: str, hub_content_name: str, region: str, session: Session) -> List[Tag]:
list_versions_response = session.list_hub_content_versions(
hub_name=hub_name,
hub_content_type='Model',
hub_content_name=hub_content_name
)
hub_content_versions = list_versions_response["HubContentSummaries"]

tag_name_to_versions_map: Dict[CuratedHubTagName, List[str]] = {}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit- rename to unsupported_hub_content_versions_map

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same as above

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I still disagree here Chris. If you make everything super generalized it loses it's context and makes it really difficult for people to read and understand. Have you heard of any use cases where we want to add additional tags?

for hub_content_version_summary in hub_content_versions:
jumpstart_model = get_jumpstart_model_and_version(hub_content_version_summary)
if jumpstart_model["model_id"] is None or jumpstart_model["version"] is None:
continue
tag_names_to_add: List[CuratedHubTagName] = find_jumpstart_tags_for_model_version(
model_id=jumpstart_model["model_id"],
version=jumpstart_model["version"],
region=region,
session=session
)

for tag_name in tag_names_to_add:
if tag_name not in tag_name_to_versions_map:
tag_name_to_versions_map[tag_name] = []
tag_name_to_versions_map[tag_name].append(hub_content_version_summary["HubContentVersion"])

tags: List[Tag] = []
for tag_name, versions in tag_name_to_versions_map.items():
tags.append(Tag(
key=tag_name,
value=str(versions)
))

return tags



def find_jumpstart_tags_for_model_version(model_id: str, version: str, region: str, session: Session) -> List[CuratedHubTagName]:
tags_to_add: List[CuratedHubTagName] = []
specs = utils.verify_model_region_and_return_specs(
model_id=model_id,
version=version,
region=region,
scope=JumpStartScriptScope.INFERENCE,
tolerate_vulnerable_model = True,
tolerate_deprecated_model = True,
sagemaker_session=session,
)

if (specs.deprecated):
tags_to_add.append(CuratedHubTagName.DEPRECATED_VERSIONS_TAG)
if (specs.inference_vulnerable):
tags_to_add.append(CuratedHubTagName.INFERENCE_VULNERABLE_VERSIONS_TAG)
if (specs.training_vulnerable):
tags_to_add.append(CuratedHubTagName.TRAINING_VULNERABLE_VERSIONS_TAG)

return tags_to_add



def get_jumpstart_model_and_version(hub_content_summary: Dict[str, Any]) -> Dict[str, Any]:
jumpstart_model_id = next(
(
tag
for tag in hub_content_summary["search_keywords"]
if tag.startswith(JUMPSTART_HUB_MODEL_ID_TAG_PREFIX)
),
None,
)
jumpstart_model_version = next(
(
tag
for tag in hub_content_summary["search_keywords"]
if tag.startswith(JUMPSTART_HUB_MODEL_VERSION_TAG_PREFIX)
),
None,
)

return {"model_id": jumpstart_model_id, "version": jumpstart_model_version}
113 changes: 113 additions & 0 deletions tests/unit/sagemaker/jumpstart/curated_hub/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,13 @@
from unittest.mock import Mock
from sagemaker.jumpstart.types import HubArnExtractedInfo
from sagemaker.jumpstart.constants import JUMPSTART_DEFAULT_REGION_NAME
from sagemaker.jumpstart.enums import JumpStartScriptScope
from sagemaker.jumpstart.curated_hub import utils
from unittest.mock import patch
from sagemaker.jumpstart.curated_hub.types import (
Tag,
CuratedHubTagName
)


def test_get_info_from_hub_resource_arn():
Expand Down Expand Up @@ -168,3 +174,110 @@ def test_create_hub_bucket_if_it_does_not_exist():

mock_sagemaker_session.boto_session.resource("s3").create_bucketassert_called_once()
assert created_hub_bucket_name == bucket_name

@patch("sagemaker.jumpstart.utils.verify_model_region_and_return_specs")
def test_find_tags_for_jumpstart_model_version(mock_spec_util):
mock_sagemaker_session = Mock()
mock_specs = Mock()
mock_specs.deprecated = True
mock_specs.inference_vulnerable = True
mock_specs.training_vulnerable = True
mock_spec_util.return_value = mock_specs

tags = utils.find_jumpstart_tags_for_model_version(
model_id="test",
version="test",
region="test",
session=mock_sagemaker_session
)

mock_spec_util.assert_called_once_with(
model_id="test",
version="test",
region="test",
scope=JumpStartScriptScope.INFERENCE,
tolerate_vulnerable_model = True,
tolerate_deprecated_model = True,
sagemaker_session=mock_sagemaker_session,
)

assert tags == [CuratedHubTagName.DEPRECATED_VERSIONS_TAG, CuratedHubTagName.INFERENCE_VULNERABLE_VERSIONS_TAG, CuratedHubTagName.TRAINING_VULNERABLE_VERSIONS_TAG]

@patch("sagemaker.jumpstart.utils.verify_model_region_and_return_specs")
def test_find_tags_for_jumpstart_model_version_some_false(mock_spec_util):
mock_sagemaker_session = Mock()
mock_specs = Mock()
mock_specs.deprecated = True
mock_specs.inference_vulnerable = False
mock_specs.training_vulnerable = False
mock_spec_util.return_value = mock_specs

tags = utils.find_jumpstart_tags_for_model_version(
model_id="test",
version="test",
region="test",
session=mock_sagemaker_session
)

mock_spec_util.assert_called_once_with(
model_id="test",
version="test",
region="test",
scope=JumpStartScriptScope.INFERENCE,
tolerate_vulnerable_model = True,
tolerate_deprecated_model = True,
sagemaker_session=mock_sagemaker_session,
)

assert tags == [CuratedHubTagName.DEPRECATED_VERSIONS_TAG]

@patch("sagemaker.jumpstart.utils.verify_model_region_and_return_specs")
def test_find_all_tags_for_jumpstart_model(mock_spec_util):
mock_sagemaker_session = Mock()
mock_sagemaker_session.list_hub_content_versions.return_value = {
"HubContentSummaries": [
{
"HubContentVersion": "1.0.0",
"search_keywords": [
"@jumpstart-model-id:model-one-pytorch",
"@jumpstart-model-version:1.0.3",
]
},
{
"HubContentVersion": "2.0.0",
"search_keywords": [
"@jumpstart-model-id:model-four-huggingface",
"@jumpstart-model-version:2.0.2",
]
},
{
"HubContentVersion": "3.0.0",
"search_keywords": []
}
]
}

mock_specs = Mock()
mock_specs.deprecated = True
mock_specs.inference_vulnerable = True
mock_specs.training_vulnerable = True
mock_spec_util.return_value = mock_specs

tags = utils.find_jumpstart_tags_for_model(
hub_name="test",
hub_content_name="test",
region="test",
session=mock_sagemaker_session
)

mock_sagemaker_session.list_hub_content_versions.assert_called_once_with(
hub_name="test",
hub_content_type='Model',
hub_content_name="test",
)

assert tags == [
Tag(key=CuratedHubTagName.DEPRECATED_VERSIONS_TAG, value=str(["1.0.0", "2.0.0"])),
Tag(key=CuratedHubTagName.INFERENCE_VULNERABLE_VERSIONS_TAG, value=str(["1.0.0", "2.0.0"])),
Tag(key=CuratedHubTagName.TRAINING_VULNERABLE_VERSIONS_TAG, value=str(["1.0.0", "2.0.0"]))
]