Skip to content

feat: Adding scan and tagging utility #4499

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from 19 commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
73ce401
fix: Adding new utils
chrstfu Mar 12, 2024
fbd233f
Merge branch 'curated_hub_tagris' into curated_hub_tagris_copy
chrstfu Mar 12, 2024
6d43cbe
feat: Adding Curated Hub scanning feature
chrstfu Mar 12, 2024
2ec0ca5
fix: Refactoring
chrstfu Mar 13, 2024
c1519b8
fix: Removing test values
chrstfu Mar 13, 2024
c2d4a17
fix: Refactoring
chrstfu Mar 13, 2024
48d6325
fix: removing lru cache temporarily
chrstfu Mar 13, 2024
0947840
fix: Adding unit tests
chrstfu Mar 13, 2024
5f8119b
fix: renaming
chrstfu Mar 13, 2024
395e99c
fix: Initial refactorings
chrstfu Mar 14, 2024
c957e71
fix: Adding more alterations
chrstfu Mar 15, 2024
453bed4
fix: Addressing unit tests
chrstfu Mar 15, 2024
9e703a9
fix: Adding more unittests
chrstfu Mar 15, 2024
7cda304
fix: Add tests
chrstfu Mar 15, 2024
9dbe3d6
fix: Adding list to scan input
chrstfu Mar 15, 2024
554dd20
fix: typo
chrstfu Mar 15, 2024
9255849
fix: Addressing naming comments
chrstfu Mar 15, 2024
6d5f599
fix: changing from string
chrstfu Mar 15, 2024
77d7b50
Merge branch 'master-jumpstart-curated-hub' into curated_hub_tagris_copy
chrstfu Mar 18, 2024
0e5d5b4
fix: formatter
chrstfu Mar 18, 2024
4b82389
fix: linter
chrstfu Mar 18, 2024
42efbaa
fix: linter
chrstfu Mar 18, 2024
a73b5b2
fix: linters
chrstfu Mar 18, 2024
4252d86
fix: linter
chrstfu Mar 18, 2024
0deb20e
fix: linting
chrstfu Mar 18, 2024
d21fd41
fix: more linting :/
chrstfu Mar 18, 2024
6b107a2
fix: linting
chrstfu Mar 18, 2024
29a0740
fix: more linting :/
chrstfu Mar 19, 2024
054a8cf
Merge branch 'master-jumpstart-curated-hub' into curated_hub_tagris_copy
chrstfu Mar 19, 2024
bddf5e3
fix: linting
chrstfu Mar 19, 2024
bd02739
fix: tests
chrstfu Mar 19, 2024
b33a63b
fix: Adding __init__.py
chrstfu Mar 19, 2024
e1da9fc
fix: linting
chrstfu Mar 19, 2024
4c829c9
fix: linting
chrstfu Mar 19, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 69 additions & 40 deletions src/sagemaker/jumpstart/curated_hub/curated_hub.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
"""This module provides the JumpStart Curated Hub class."""
from __future__ import absolute_import
from concurrent import futures
from functools import lru_cache
from datetime import datetime
import json
import traceback
Expand Down Expand Up @@ -47,17 +48,24 @@
create_hub_bucket_if_it_does_not_exist,
generate_default_hub_bucket_name,
create_s3_object_reference_from_uri,
tag_hub_content,
get_jumpstart_model_and_version,
find_unsupported_flags_for_hub_content_versions,
summary_list_from_list_api_response,
)
from sagemaker.jumpstart.curated_hub.types import (
HubContentDocument_v2,
JumpStartModelInfo,
S3ObjectLocation,
HubContentSummary,
)

from sagemaker.utils import TagsDict

class CuratedHub:
"""Class for creating and managing a curated JumpStart hub"""

_list_hubs_cache: Dict[str, Any] = None

def __init__(
self,
hub_name: str,
Expand Down Expand Up @@ -143,18 +151,22 @@ def describe(self) -> DescribeHubResponse:

return hub_description

def list_models(self, **kwargs) -> Dict[str, Any]:
"""Lists the models in this Curated Hub

def list_models(self, clear_cache: bool = True, **kwargs) -> List[Dict[str, Any]]:
Copy link
Member

@evakravi evakravi Mar 18, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is clear_cache used anywhere else in this module?

personally i'm not a fan of this, i'd prefer if you use @lru_cache annotation, it's much simpler imo

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

wdyt @bencrabtree?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't have a preference but curious how we can force a fetch with the lru_cache decorator. We added this to improve performance in cases where we don't want to make this call again- it's not particularly an LRU cache

"""Lists the models in this Curated Hub.

This function caches the models in local memory

**kwargs: Passed to invocation of ``Session:list_hub_contents``.
"""
# TODO: Validate kwargs and fast-fail?

hub_content_summaries = self._sagemaker_session.list_hub_contents(
if clear_cache:
self._list_hubs_cache = None
if self._list_hubs_cache is None:
hub_content_summaries = self._sagemaker_session.list_hub_contents(
hub_name=self.hub_name, hub_content_type=HubContentType.MODEL, **kwargs
)
# TODO: Handle pagination
return hub_content_summaries
)
self._list_hubs_cache = hub_content_summaries
return self._list_hubs_cache

def describe_model(
self, model_name: str, model_version: str = "*"
Expand Down Expand Up @@ -205,37 +217,12 @@ def _populate_latest_model_version(self, model: Dict[str, str]) -> Dict[str, str
)
return {"model_id": model["model_id"], "version": model_specs.version}

def _get_jumpstart_models_in_hub(self) -> List[Dict[str, Any]]:
"""Returns list of `HubContent` that have been created from a JumpStart model."""
hub_models = self.list_models()

js_models_in_hub = []
for hub_model in hub_models["HubContentSummaries"]:
# TODO: extract both in one pass
jumpstart_model_id = next(
(
tag
for tag in hub_model["search_keywords"]
if tag.startswith(JUMPSTART_HUB_MODEL_ID_TAG_PREFIX)
),
None,
)
jumpstart_model_version = next(
(
tag
for tag in hub_model["search_keywords"]
if tag.startswith(JUMPSTART_HUB_MODEL_VERSION_TAG_PREFIX)
),
None,
)

if jumpstart_model_id and jumpstart_model_version:
js_models_in_hub.append(hub_model)

return js_models_in_hub
def _get_jumpstart_models_in_hub(self) -> List[HubContentSummary]:
hub_models = summary_list_from_list_api_response(self.list_models())
return [model for model in hub_models if get_jumpstart_model_and_version(model) is not None]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why would they be None?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

and would we want to log a warning in that case?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The model ID and version from this function comes from the SearchKeywords, so they could be None.

Theoretically, we could change it to is_jumpstart_model since we don't use the model ID and version from the tags; we use the HubContentName and HubContentVersion as the SOT based on the algorithm. That may help with clarity, wdyt @bencrabtree?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We'd need to use the get_* function in order to retrieve the values for other logic in sync, but I don't mind making another helper util that uses the get_* function internally


def _determine_models_to_sync(
self, model_list: List[JumpStartModelInfo], models_in_hub: Dict[str, Any]
self, model_list: List[JumpStartModelInfo], models_in_hub: Dict[str, HubContentSummary]
) -> List[JumpStartModelInfo]:
"""Determines which models from `sync` params to sync into the CuratedHub.

Expand All @@ -257,7 +244,7 @@ def _determine_models_to_sync(

if matched_model:
model_version = Version(model.version)
hub_model_version = Version(matched_model["version"])
hub_model_version = Version(matched_model.hub_content_version)

# 1. Model version exists in Hub, pass
if hub_model_version == model_version:
Expand Down Expand Up @@ -302,7 +289,7 @@ def sync(self, model_list: List[Dict[str, str]]):
model_version_list.append(JumpStartModelInfo(model["model_id"], model["version"]))

js_models_in_hub = self._get_jumpstart_models_in_hub()
mapped_models_in_hub = {model["name"]: model for model in js_models_in_hub}
mapped_models_in_hub = {model.hub_content_name: model for model in js_models_in_hub}

models_to_sync = self._determine_models_to_sync(model_version_list, mapped_models_in_hub)
JUMPSTART_LOGGER.warning(
Expand Down Expand Up @@ -415,3 +402,45 @@ def _fetch_studio_specs(self, model_specs: JumpStartModelSpecs) -> Dict[str, Any
Bucket=utils.get_jumpstart_content_bucket(self.region), Key=key
)
return json.loads(response["Body"].read().decode("utf-8"))


def scan_and_tag_models(self, model_list: List[Dict[str, str]] = None) -> None:
"""Scans the Hub for JumpStart models and tags the HubContent.

If the scan detects a model is deprecated or vulnerable, it will tag the HubContent.
The tags that will be added are based off the specifications in the JumpStart public hub:
1. "deprecated_versions" -> If the public hub model is deprecated
2. "inference_vulnerable_versions" -> If the public hub model has inference vulnerabilities
3. "training_vulnerable_versions" -> If the public hub model has training vulnerabilities

The tag value will be a list of versions in the Curated Hub that fall under those keys.
For example, if model_a version_a is deprecated and inference is vulnerable, the
HubContent for `model_a` will have tags [{"deprecated_versions": [version_a]},
{"inference_vulnerable_versions": [version_a]}]

If models are passed in,
"""
JUMPSTART_LOGGER.info(
"Tagging models in hub: %s", self.hub_name
)
if self._is_invalid_model_list_input(model_list):
raise ValueError(
"Model list should be a list of objects with values 'model_id',",
"and optional 'version'.",
)

models_to_scan = model_list if model_list else self.list_models()
js_models_in_hub = [model for model in models_to_scan if get_jumpstart_model_and_version(model) is not None]
for model in js_models_in_hub:
tags_to_add: List[TagsDict] = find_unsupported_flags_for_hub_content_versions(
hub_name=self.hub_name,
hub_content_name=model.hub_content_name,
region=self.region,
session=self._sagemaker_session
)
tag_hub_content(
hub_content_arn=model.hub_content_arn,
tags=tags_to_add,
session=self._sagemaker_session
)
JUMPSTART_LOGGER.info("Tagging complete!")
24 changes: 21 additions & 3 deletions src/sagemaker/jumpstart/curated_hub/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,32 @@
# language governing permissions and limitations under the License.
"""This module stores types related to SageMaker JumpStart CuratedHub."""
from __future__ import absolute_import
from typing import Dict, Any, Optional
from typing import Dict, Any, Optional, List
from enum import Enum
from dataclasses import dataclass
from datetime import datetime

from sagemaker.jumpstart.types import JumpStartDataHolderType, JumpStartModelSpecs
from sagemaker.jumpstart.types import JumpStartDataHolderType, JumpStartModelSpecs, HubContentType

class CuratedHubUnsupportedFlag(str, Enum):
"""Enum class for Curated Hub tag names."""
DEPRECATED_VERSIONS = "deprecated_versions"
TRAINING_VULNERABLE_VERSIONS = "training_vulnerable_versions"
INFERENCE_VULNERABLE_VERSIONS = "inference_vulnerable_versions"

@dataclass
class HubContentSummary:
"""Dataclass to store HubContentSummary from List APIs."""
hub_content_arn: str
hub_content_name: str
hub_content_version: str
hub_content_type: HubContentType
document_schema_version: str
hub_content_status: str
creation_time: str
hub_content_display_name: str = None
hub_content_description: str = None
hub_content_search_keywords: List[str] = None

@dataclass
class S3ObjectLocation:
Expand All @@ -38,7 +57,6 @@ def get_uri(self) -> str:
"""Returns the s3 URI"""
return f"s3://{self.bucket}/{self.key}"


@dataclass
class JumpStartModelInfo:
"""Helper class for storing JumpStart model info."""
Expand Down
134 changes: 133 additions & 1 deletion src/sagemaker/jumpstart/curated_hub/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,35 @@
import re
from typing import Optional
from sagemaker.jumpstart.curated_hub.types import S3ObjectLocation
from sagemaker.jumpstart.constants import JUMPSTART_LOGGER
from sagemaker.s3_utils import parse_s3_url
from sagemaker.session import Session
from sagemaker.utils import aws_partition
from typing import Optional, Dict, List, Any, Set
from botocore.exceptions import ClientError
from sagemaker.jumpstart.types import (
HubContentType,
HubArnExtractedInfo,
HubArnExtractedInfo
)
from sagemaker.jumpstart.curated_hub.types import (
CuratedHubUnsupportedFlag,
HubContentSummary,
JumpStartModelInfo
)
from sagemaker.jumpstart import constants
from sagemaker.jumpstart import utils
from sagemaker.session import Session
from sagemaker.jumpstart.enums import JumpStartScriptScope
from sagemaker.jumpstart.curated_hub.constants import (
JUMPSTART_HUB_MODEL_ID_TAG_PREFIX,
JUMPSTART_HUB_MODEL_VERSION_TAG_PREFIX,
TASK_TAG_PREFIX,
FRAMEWORK_TAG_PREFIX,
)
from sagemaker.utils import (
format_tags,
TagsDict
)


def get_info_from_hub_resource_arn(
Expand Down Expand Up @@ -164,3 +185,114 @@ def create_hub_bucket_if_it_does_not_exist(
)

return bucket_name

def tag_hub_content(hub_content_arn: str, tags: List[TagsDict], session: Session) -> None:
session.add_tags(
ResourceArn=hub_content_arn,
Tags=tags
)
JUMPSTART_LOGGER.info(f"Added tags to HubContentArn %s: %s", hub_content_arn, TagsDict)

def find_unsupported_flags_for_hub_content_versions(hub_name: str, hub_content_name: str, region: str, session: Session) -> List[TagsDict]:
"""Finds the JumpStart public hub model for a HubContent and calculates relevant tags.

Since tags are the same for all versions of a HubContent, these tags will map from the key to a list of versions impacted.
For example, if certain public hub model versions are deprecated,
this utility will return a `deprecated` tag mapped to the deprecated versions for the HubContent.
"""
list_versions_response = session.list_hub_content_versions(
hub_name=hub_name,
hub_content_type=HubContentType.MODEL,
hub_content_name=hub_content_name
)
hub_content_versions: List[HubContentSummary] = summary_list_from_list_api_response(list_versions_response)

unsupported_hub_content_versions_map: Dict[str, List[str]] = {}
for hub_content_version_summary in hub_content_versions:
jumpstart_model = get_jumpstart_model_and_version(hub_content_version_summary)
if jumpstart_model is None:
continue
tag_names_to_add: List[CuratedHubUnsupportedFlag] = find_unsupported_flags_for_model_version(
model_id=jumpstart_model.model_id,
version=jumpstart_model.version,
region=region,
session=session
)

for tag_name in tag_names_to_add:
if tag_name not in unsupported_hub_content_versions_map:
unsupported_hub_content_versions_map[tag_name.value] = []
unsupported_hub_content_versions_map[tag_name.value].append(hub_content_version_summary.hub_content_version)

return format_tags(unsupported_hub_content_versions_map)


def find_unsupported_flags_for_model_version(model_id: str, version: str, region: str, session: Session) -> List[CuratedHubUnsupportedFlag]:
"""Finds relevant CuratedHubTags for a version of a JumpStart public hub model.

For example, if the public hub model is deprecated, this utility will return a `deprecated` tag.
Since tags are the same for all versions of a HubContent, these tags will map from the key to a list of versions impacted.
"""
flags_to_add: List[CuratedHubUnsupportedFlag] = []
jumpstart_model_specs = utils.verify_model_region_and_return_specs(
model_id=model_id,
version=version,
region=region,
scope=JumpStartScriptScope.INFERENCE,
tolerate_vulnerable_model = True,
tolerate_deprecated_model = True,
sagemaker_session=session,
)

if (jumpstart_model_specs.deprecated):
flags_to_add.append(CuratedHubUnsupportedFlag.DEPRECATED_VERSIONS)
if (jumpstart_model_specs.inference_vulnerable):
flags_to_add.append(CuratedHubUnsupportedFlag.INFERENCE_VULNERABLE_VERSIONS)
if (jumpstart_model_specs.training_vulnerable):
flags_to_add.append(CuratedHubUnsupportedFlag.TRAINING_VULNERABLE_VERSIONS)

return flags_to_add



def get_jumpstart_model_and_version(hub_content_summary: HubContentSummary) -> Optional[JumpStartModelInfo]:
jumpstart_model_id_tag = next(
(
tag
for tag in hub_content_summary.hub_content_search_keywords
if tag.startswith(JUMPSTART_HUB_MODEL_ID_TAG_PREFIX)
),
None,
)
jumpstart_model_version_tag = next(
(
tag
for tag in hub_content_summary.hub_content_search_keywords
if tag.startswith(JUMPSTART_HUB_MODEL_VERSION_TAG_PREFIX)
),
None,
)

if jumpstart_model_id_tag is None or jumpstart_model_version_tag is None:
return None
jumpstart_model_id = jumpstart_model_id_tag[len(JUMPSTART_HUB_MODEL_ID_TAG_PREFIX):] # Need to remove the tag_prefix and ":"
jumpstart_model_version = jumpstart_model_version_tag[len(JUMPSTART_HUB_MODEL_VERSION_TAG_PREFIX):]
return JumpStartModelInfo(model_id=jumpstart_model_id, version=jumpstart_model_version)

def summary_from_list_api_response(hub_content_summary: Dict[str, Any]) -> HubContentSummary:
return HubContentSummary(
hub_content_arn=hub_content_summary.get("HubContentArn"),
hub_content_name=hub_content_summary.get("HubContentName"),
hub_content_version=hub_content_summary.get("HubContentVersion"),
hub_content_type=hub_content_summary.get("HubContentType"),
document_schema_version=hub_content_summary.get("DocumentSchemaVersion"),
hub_content_status=hub_content_summary.get("HubContentStatus"),
hub_content_display_name=hub_content_summary.get("HubContentDisplayName"),
hub_content_description=hub_content_summary.get("HubContentDescription"),
hub_content_search_keywords=hub_content_summary.get("HubContentSearchKeywords"),
creation_time=hub_content_summary.get("CreationTime")
)

def summary_list_from_list_api_response(list_hub_contents_response: Dict[str, Any]) -> List[HubContentSummary]:
return list(map(summary_from_list_api_response, list_hub_contents_response["HubContentSummaries"]))

Loading