-
Notifications
You must be signed in to change notification settings - Fork 1.2k
feat: Adding scan and tagging utility #4499
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 19 commits
73ce401
fbd233f
6d43cbe
2ec0ca5
c1519b8
c2d4a17
48d6325
0947840
5f8119b
395e99c
c957e71
453bed4
9e703a9
7cda304
9dbe3d6
554dd20
9255849
6d5f599
77d7b50
0e5d5b4
4b82389
42efbaa
a73b5b2
4252d86
0deb20e
d21fd41
6b107a2
29a0740
054a8cf
bddf5e3
bd02739
b33a63b
e1da9fc
4c829c9
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -13,6 +13,7 @@ | |
"""This module provides the JumpStart Curated Hub class.""" | ||
from __future__ import absolute_import | ||
from concurrent import futures | ||
from functools import lru_cache | ||
from datetime import datetime | ||
import json | ||
import traceback | ||
|
@@ -47,17 +48,24 @@ | |
create_hub_bucket_if_it_does_not_exist, | ||
generate_default_hub_bucket_name, | ||
create_s3_object_reference_from_uri, | ||
tag_hub_content, | ||
get_jumpstart_model_and_version, | ||
find_unsupported_flags_for_hub_content_versions, | ||
summary_list_from_list_api_response, | ||
) | ||
from sagemaker.jumpstart.curated_hub.types import ( | ||
HubContentDocument_v2, | ||
JumpStartModelInfo, | ||
S3ObjectLocation, | ||
HubContentSummary, | ||
) | ||
|
||
from sagemaker.utils import TagsDict | ||
|
||
class CuratedHub: | ||
"""Class for creating and managing a curated JumpStart hub""" | ||
|
||
_list_hubs_cache: Dict[str, Any] = None | ||
|
||
def __init__( | ||
self, | ||
hub_name: str, | ||
|
@@ -143,18 +151,22 @@ def describe(self) -> DescribeHubResponse: | |
|
||
return hub_description | ||
|
||
def list_models(self, **kwargs) -> Dict[str, Any]: | ||
"""Lists the models in this Curated Hub | ||
|
||
def list_models(self, clear_cache: bool = True, **kwargs) -> List[Dict[str, Any]]: | ||
"""Lists the models in this Curated Hub. | ||
|
||
This function caches the models in local memory | ||
|
||
**kwargs: Passed to invocation of ``Session:list_hub_contents``. | ||
""" | ||
# TODO: Validate kwargs and fast-fail? | ||
|
||
hub_content_summaries = self._sagemaker_session.list_hub_contents( | ||
if clear_cache: | ||
self._list_hubs_cache = None | ||
if self._list_hubs_cache is None: | ||
hub_content_summaries = self._sagemaker_session.list_hub_contents( | ||
hub_name=self.hub_name, hub_content_type=HubContentType.MODEL, **kwargs | ||
) | ||
# TODO: Handle pagination | ||
return hub_content_summaries | ||
) | ||
self._list_hubs_cache = hub_content_summaries | ||
return self._list_hubs_cache | ||
|
||
def describe_model( | ||
self, model_name: str, model_version: str = "*" | ||
|
@@ -205,37 +217,12 @@ def _populate_latest_model_version(self, model: Dict[str, str]) -> Dict[str, str | |
) | ||
return {"model_id": model["model_id"], "version": model_specs.version} | ||
|
||
def _get_jumpstart_models_in_hub(self) -> List[Dict[str, Any]]: | ||
"""Returns list of `HubContent` that have been created from a JumpStart model.""" | ||
hub_models = self.list_models() | ||
|
||
js_models_in_hub = [] | ||
for hub_model in hub_models["HubContentSummaries"]: | ||
# TODO: extract both in one pass | ||
jumpstart_model_id = next( | ||
( | ||
tag | ||
for tag in hub_model["search_keywords"] | ||
if tag.startswith(JUMPSTART_HUB_MODEL_ID_TAG_PREFIX) | ||
), | ||
None, | ||
) | ||
jumpstart_model_version = next( | ||
( | ||
tag | ||
for tag in hub_model["search_keywords"] | ||
if tag.startswith(JUMPSTART_HUB_MODEL_VERSION_TAG_PREFIX) | ||
), | ||
None, | ||
) | ||
|
||
if jumpstart_model_id and jumpstart_model_version: | ||
js_models_in_hub.append(hub_model) | ||
|
||
return js_models_in_hub | ||
def _get_jumpstart_models_in_hub(self) -> List[HubContentSummary]: | ||
hub_models = summary_list_from_list_api_response(self.list_models()) | ||
return [model for model in hub_models if get_jumpstart_model_and_version(model) is not None] | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Why would they be `None`? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. And would we want to log a warning in that case? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The model ID and version from this function come from the […]. Theoretically, we could change it to […]. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. We'd need to use the […]. |
||
|
||
def _determine_models_to_sync( | ||
self, model_list: List[JumpStartModelInfo], models_in_hub: Dict[str, Any] | ||
self, model_list: List[JumpStartModelInfo], models_in_hub: Dict[str, HubContentSummary] | ||
) -> List[JumpStartModelInfo]: | ||
"""Determines which models from `sync` params to sync into the CuratedHub. | ||
|
||
|
@@ -257,7 +244,7 @@ def _determine_models_to_sync( | |
|
||
if matched_model: | ||
model_version = Version(model.version) | ||
hub_model_version = Version(matched_model["version"]) | ||
hub_model_version = Version(matched_model.hub_content_version) | ||
|
||
# 1. Model version exists in Hub, pass | ||
if hub_model_version == model_version: | ||
|
@@ -302,7 +289,7 @@ def sync(self, model_list: List[Dict[str, str]]): | |
model_version_list.append(JumpStartModelInfo(model["model_id"], model["version"])) | ||
|
||
js_models_in_hub = self._get_jumpstart_models_in_hub() | ||
mapped_models_in_hub = {model["name"]: model for model in js_models_in_hub} | ||
mapped_models_in_hub = {model.hub_content_name: model for model in js_models_in_hub} | ||
|
||
models_to_sync = self._determine_models_to_sync(model_version_list, mapped_models_in_hub) | ||
JUMPSTART_LOGGER.warning( | ||
|
@@ -415,3 +402,45 @@ def _fetch_studio_specs(self, model_specs: JumpStartModelSpecs) -> Dict[str, Any | |
Bucket=utils.get_jumpstart_content_bucket(self.region), Key=key | ||
) | ||
return json.loads(response["Body"].read().decode("utf-8")) | ||
|
||
|
||
def scan_and_tag_models(self, model_list: List[Dict[str, str]] = None) -> None:
    """Scans the Hub for JumpStart models and tags the HubContent.

    If the scan detects a model is deprecated or vulnerable, it will tag the HubContent.
    The tags that will be added are based off the specifications in the JumpStart public hub:
    1. "deprecated_versions" -> If the public hub model is deprecated
    2. "inference_vulnerable_versions" -> If the public hub model has inference vulnerabilities
    3. "training_vulnerable_versions" -> If the public hub model has training vulnerabilities

    The tag value will be a list of versions in the Curated Hub that fall under those keys.
    For example, if model_a version_a is deprecated and inference is vulnerable, the
    HubContent for `model_a` will have tags [{"deprecated_versions": [version_a]},
    {"inference_vulnerable_versions": [version_a]}]

    If models are passed in, only those models are scanned; otherwise every model
    returned by ``list_models`` for this hub is scanned.

    Args:
        model_list: Optional list of models to scan, described by the validation error
            as objects with a 'model_id' and an optional 'version'. ``None`` or an
            empty list scans the whole hub.

    Raises:
        ValueError: If an explicit ``model_list`` is rejected by
            ``_is_invalid_model_list_input``.
    """
    JUMPSTART_LOGGER.info("Tagging models in hub: %s", self.hub_name)

    # ``None`` means "scan the whole hub", so only an explicit list is validated.
    if model_list is not None and self._is_invalid_model_list_input(model_list):
        # Adjacent literals are concatenated into ONE message. The previous comma
        # between them passed two positional args, so str(error) rendered a tuple.
        raise ValueError(
            "Model list should be a list of objects with values 'model_id', "
            "and optional 'version'."
        )

    if model_list:
        # NOTE(review): entries are annotated as dicts, yet the loop below reads
        # `.hub_content_name` / `.hub_content_arn` attributes from each model.
        # Confirm how caller-supplied entries should be resolved to hub content.
        models_to_scan = model_list
    else:
        # Convert the raw list response into summaries first, the same way
        # `_get_jumpstart_models_in_hub` does, instead of iterating the response itself.
        models_to_scan = summary_list_from_list_api_response(self.list_models())

    for model in models_to_scan:
        # Skip hub content that is not recognised as a JumpStart model.
        if get_jumpstart_model_and_version(model) is None:
            continue

        tags_to_add: List[TagsDict] = find_unsupported_flags_for_hub_content_versions(
            hub_name=self.hub_name,
            hub_content_name=model.hub_content_name,
            region=self.region,
            session=self._sagemaker_session,
        )
        tag_hub_content(
            hub_content_arn=model.hub_content_arn,
            tags=tags_to_add,
            session=self._sagemaker_session,
        )

    JUMPSTART_LOGGER.info("Tagging complete!")
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is `clear_cache` used anywhere else in this module? Personally, I'm not a fan of this; I'd prefer that you use the `@lru_cache` decorator, which is much simpler IMO.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
wdyt @bencrabtree?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't have a preference, but I'm curious how we can force a fetch with the `lru_cache` decorator. We added this to improve performance in cases where we don't want to make this call again; it's not particularly an LRU cache.