feature: JumpStart CuratedHub class creation and function definitions #4448

Merged
50 changes: 30 additions & 20 deletions src/sagemaker/jumpstart/cache.py
@@ -21,34 +21,38 @@
import botocore
from packaging.version import Version
from packaging.specifiers import SpecifierSet, InvalidSpecifier
from sagemaker.session import Session
from sagemaker.utilities.cache import LRUCache
from sagemaker.jumpstart.constants import (
ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE,
ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE,
JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY,
JUMPSTART_DEFAULT_REGION_NAME,
JUMPSTART_LOGGER,
MODEL_ID_LIST_WEB_URL,
DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
)
from sagemaker.jumpstart.curated_hub.curated_hub import CuratedHub
from sagemaker.jumpstart.curated_hub.utils import get_info_from_hub_resource_arn
from sagemaker.jumpstart.exceptions import get_wildcard_model_version_msg
from sagemaker.jumpstart.parameters import (
JUMPSTART_DEFAULT_MAX_S3_CACHE_ITEMS,
JUMPSTART_DEFAULT_MAX_SEMANTIC_VERSION_CACHE_ITEMS,
JUMPSTART_DEFAULT_S3_CACHE_EXPIRATION_HORIZON,
JUMPSTART_DEFAULT_SEMANTIC_VERSION_CACHE_EXPIRATION_HORIZON,
)
from sagemaker.jumpstart import utils
from sagemaker.jumpstart.types import (
JumpStartCachedContentKey,
JumpStartCachedContentValue,
JumpStartModelHeader,
JumpStartModelSpecs,
JumpStartS3FileType,
JumpStartVersionedModelId,
DescribeHubResponse,
DescribeHubContentsResponse,
HubType,
HubContentType,
)
from sagemaker.jumpstart import utils
from sagemaker.utilities.cache import LRUCache
from sagemaker.jumpstart.curated_hub import utils as hub_utils


class JumpStartModelsCache:
@@ -74,6 +78,7 @@ def __init__(
s3_bucket_name: Optional[str] = None,
s3_client_config: Optional[botocore.config.Config] = None,
s3_client: Optional[boto3.client] = None,
sagemaker_session: Optional[Session] = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
) -> None: # fmt: on
"""Initialize a ``JumpStartModelsCache`` instance.

@@ -95,6 +100,8 @@ def __init__(
s3_client_config (Optional[botocore.config.Config]): s3 client config to use for cache.
Default: None (no config).
s3_client (Optional[boto3.client]): s3 client to use. Default: None.
sagemaker_session (Optional[sagemaker.session.Session]): A SageMaker Session object,
used for SageMaker interactions. Default: Session in region associated with boto3 session.
"""

self._region = region
@@ -121,6 +128,7 @@ def __init__(
if s3_client_config
else boto3.client("s3", region_name=self._region)
)
self._sagemaker_session = sagemaker_session

def set_region(self, region: str) -> None:
"""Set region for cache. Clears cache after new region is set."""
@@ -340,32 +348,34 @@ def _retrieval_function(
formatted_content=model_specs
)
if data_type == HubContentType.MODEL:
info = get_info_from_hub_resource_arn(
hub_name, _, model_name, model_version = hub_utils.get_info_from_hub_resource_arn(
id_info
)
hub = CuratedHub(hub_name=info.hub_name, region=info.region)
hub_content = hub.describe_model(
model_name=info.hub_content_name, model_version=info.hub_content_version
hub_model_description: Dict[str, Any] = self._sagemaker_session.describe_hub_content(
Contributor

Don't we need an existence check on the session variable since it could be None?

Contributor Author

I had self._sagemaker_session default to DEFAULT_JUMPSTART_SAGEMAKER_SESSION, unless that changed. I will take a look in the next PR.
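
The existence check raised in this thread could take roughly the following shape. This is a minimal sketch, not the PR's code: `FakeSession`, `DEFAULT_SESSION`, `resolve_session`, and `CacheSketch` are hypothetical stand-ins for `sagemaker.session.Session`, `DEFAULT_JUMPSTART_SAGEMAKER_SESSION`, and whatever helper a follow-up PR might add. The point it illustrates is that a default argument does not protect against a caller passing `None` explicitly.

```python
from typing import Optional


class FakeSession:
    """Hypothetical stand-in for sagemaker.session.Session."""

    def __init__(self, region: str) -> None:
        self.boto_region_name = region


# Hypothetical stand-in for DEFAULT_JUMPSTART_SAGEMAKER_SESSION.
DEFAULT_SESSION = FakeSession("us-west-2")


def resolve_session(session: Optional[FakeSession]) -> FakeSession:
    """Return the given session, or the default when the caller passed None."""
    return session if session is not None else DEFAULT_SESSION


class CacheSketch:
    """Illustrative cache that never stores None as its session."""

    def __init__(self, sagemaker_session: Optional[FakeSession] = DEFAULT_SESSION) -> None:
        # An explicit None overrides the default argument, so guard here.
        self._sagemaker_session = resolve_session(sagemaker_session)
```

With a guard like this, later calls on `self._sagemaker_session` need no per-call `None` check.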

hub_name=hub_name,
hub_content_name=model_name,
hub_content_version=model_version,
hub_content_type=data_type
)

model_specs = JumpStartModelSpecs(DescribeHubContentsResponse(hub_model_description), is_hub_content=True)
Collaborator

An alternative would be to wrap the response in Session with DescribeHubContentsResponse. I'd like to hear @evakravi's or @akozd's opinion on this.

Member

We should do one or the other, and stick with that throughout the session module. Personally, I'd prefer returning the class rather than a dictionary.

Contributor Author

I did make data classes for all HubAPI responses. But having those as return types in session would cause a circular dependency, since we would have to import jumpstart.types in session and import session in jumpstart.types, which suggests the code is not well organized. I will leave the session return types as dictionaries and typecast somewhere else.
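
The arrangement described here, where the session layer returns a plain dictionary and the caller wraps it, might look like the sketch below. Everything in it is an assumption for illustration: `describe_hub` stands in for a session method, `HubDescriptionSketch` for a response class such as `DescribeHubResponse`, and the key names are not taken from the PR. Because the session side imports nothing from the types module, no import cycle arises.

```python
from typing import Any, Dict


def describe_hub(hub_name: str) -> Dict[str, Any]:
    """Hypothetical stand-in for a session method: returns a plain dict."""
    return {"HubName": hub_name, "HubStatus": "InService"}


class HubDescriptionSketch:
    """Hypothetical typed view over the dict, defined outside the session module."""

    def __init__(self, response: Dict[str, Any]) -> None:
        # Key names are assumptions for illustration.
        self.hub_name: str = response["HubName"]
        self.hub_status: str = response.get("HubStatus", "")


# The session layer returns a dict; the caller converts it.
description = HubDescriptionSketch(describe_hub("my-hub"))
```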

Contributor

Can we keep the constructor for JumpStartModelSpecs as taking just the spec dictionary? The from_hub_content() method could easily be a function outside of that class. This keeps us from going down a route where we may end up violating the Single Responsibility Principle on JumpStartModelSpecs.

Contributor Author

@whittech1 To clarify, are you pointing out the typecast in

model_specs = JumpStartModelSpecs(DescribeHubContentsResponse(hub_model_description), is_hub_content=True)

and instead it should be:

model_specs = JumpStartModelSpecs(hub_model_description, is_hub_content=True)

?

Contributor

@whittech1, Mar 6, 2024

I am saying that the constructor for JumpStartModelSpecs does not need is_hub_content as a flag. The caller can do:

def convert_hub_content_to_js_model_specs(hub_content):
    ...
    # return a dict that matches `JumpStartModelSpecs` constructor's dictionary


...

spec_dict = convert_hub_content_to_js_model_specs(DescribeHubContentsResponse(hub_model_description))
model_specs = JumpStartModelSpecs(spec_dict)
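
A runnable version of the suggestion above might look like this. It is a sketch under stated assumptions: `ModelSpecsSketch` is a hypothetical stand-in for `JumpStartModelSpecs`, the `model_id` and `version` keys are invented for the example, and the hub-content key names (including a JSON-string `HubContentDocument`) are assumed rather than taken from the PR. The conversion lives in a free function, so the constructor keeps a single input shape and needs no `is_hub_content` flag.

```python
import json
from typing import Any, Dict


class ModelSpecsSketch:
    """Hypothetical stand-in for JumpStartModelSpecs: takes only a spec dict."""

    def __init__(self, spec: Dict[str, Any]) -> None:
        self.model_id: str = spec["model_id"]
        self.version: str = spec["version"]


def convert_hub_content_to_spec_dict(hub_content: Dict[str, Any]) -> Dict[str, Any]:
    """Map a hub-content description onto the dict shape the constructor expects."""
    # Assumption: the document field holds a JSON string.
    document = json.loads(hub_content["HubContentDocument"])
    return {
        "model_id": hub_content["HubContentName"],
        "version": hub_content["HubContentVersion"],
        **document,
    }
```

A caller would then write `ModelSpecsSketch(convert_hub_content_to_spec_dict(response))`, keeping the hub-specific parsing out of the specs class.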


utils.emit_logs_based_on_model_specs(
hub_content.content_document,
model_specs,
self.get_region(),
self._s3_client
)
model_specs = JumpStartModelSpecs(hub_content.content_document, is_hub_content=True)
return JumpStartCachedContentValue(
formatted_content=model_specs
)
if data_type == HubContentType.HUB:
info = get_info_from_hub_resource_arn(
id_info
)
hub = CuratedHub(hub_name=info.hub_name, region=info.region)
hub_info = hub.describe()
return JumpStartCachedContentValue(formatted_content=hub_info)
if data_type == HubType.HUB:
hub_name, _, _, _ = hub_utils.get_info_from_hub_resource_arn(id_info)
Member

Can you make the return type a data class in the next PR?

Contributor Author

Clarifying question: Did you mean something like this?

            hub_info: HubInfo = hub_utils.get_info_from_hub_resource_arn(id_info)

Member

yes, exactly!
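
For illustration, a data-class return type for ARN parsing could be sketched as follows. The `HubInfo` name comes from the comment above; its fields, the `parse_hub_arn` helper, and the ARN shapes in the regular expression are assumptions made for this example and may not match the repository's `HubArnExtractedInfo` or the real ARN formats.

```python
import re
from dataclasses import dataclass
from typing import Optional

# Assumed ARN shapes, for illustration only:
#   arn:<partition>:sagemaker:<region>:<account>:hub/<hub-name>
#   arn:<partition>:sagemaker:<region>:<account>:hub-content/<hub-name>/<type>/<content-name>
_HUB_ARN = re.compile(
    r"^arn:(?P<partition>[^:]+):sagemaker:(?P<region>[^:]+):(?P<account>\d+):"
    r"(?:hub/(?P<hub>[^/]+)|hub-content/(?P<hub2>[^/]+)/[^/]+/(?P<content>[^/]+))$"
)


@dataclass(frozen=True)
class HubInfo:
    """Named fields instead of a positional tuple."""

    partition: str
    region: str
    account_id: str
    hub_name: str
    hub_content_name: Optional[str] = None


def parse_hub_arn(arn: str) -> Optional[HubInfo]:
    """Return a HubInfo for a recognized ARN, or None if it does not match."""
    match = _HUB_ARN.match(arn)
    if match is None:
        return None
    return HubInfo(
        partition=match["partition"],
        region=match["region"],
        account_id=match["account"],
        hub_name=match["hub"] or match["hub2"],
        hub_content_name=match["content"],
    )
```

Callers can then read `info.hub_name` rather than unpacking `hub_name, _, _, _`, which stays correct if fields are added later.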

response: Dict[str, Any] = self._sagemaker_session.describe_hub(hub_name=hub_name)
hub_description = DescribeHubResponse(response)
return JumpStartCachedContentValue(formatted_content=DescribeHubResponse(hub_description))
raise ValueError(
f"Bad value for key '{key}': must be in",
f"{[JumpStartS3FileType.MANIFEST, JumpStartS3FileType.SPECS, HubContentType.HUB, HubContentType.MODEL]}"
f"Bad value for key '{key}': must be in ",
f"{[JumpStartS3FileType.MANIFEST, JumpStartS3FileType.SPECS, HubType.HUB, HubContentType.MODEL]}"
)
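
One detail in the `raise ValueError(...)` call above is worth noting: the comma after the first f-string makes the two strings separate positional arguments rather than one concatenated message. The standalone sketch below shows the difference; the message text is illustrative and the names are not from the PR.

```python
key = "example"

# With a comma, the exception receives two arguments.
with_comma = ValueError(
    f"Bad value for key '{key}': must be in ",
    "[MANIFEST, SPECS]",
)

# Without the comma, adjacent string literals concatenate into one message.
without_comma = ValueError(
    f"Bad value for key '{key}': must be in "
    "[MANIFEST, SPECS]"
)
```

If a single readable message is intended, dropping the comma is the smaller change.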

def get_manifest(self) -> List[JumpStartModelHeader]:
@@ -490,7 +500,7 @@ def get_hub(self, hub_arn: str) -> Dict[str, Any]:
hub_arn (str): Arn for the Hub to get info for
"""

details, _ = self._content_cache.get(JumpStartCachedContentKey(HubContentType.HUB, hub_arn))
details, _ = self._content_cache.get(JumpStartCachedContentKey(HubType.HUB, hub_arn))
return details.formatted_content

def clear(self) -> None:
97 changes: 79 additions & 18 deletions src/sagemaker/jumpstart/curated_hub/curated_hub.py
@@ -13,10 +13,16 @@
"""This module provides the JumpStart Curated Hub class."""
from __future__ import absolute_import

from typing import Optional, Dict, Any
from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION

from typing import Any, Dict, Optional
from sagemaker.session import Session
from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION
from sagemaker.jumpstart.types import (
DescribeHubResponse,
DescribeHubContentsResponse,
HubContentType,
)
from sagemaker.jumpstart.curated_hub.utils import create_hub_bucket_if_it_does_not_exist


class CuratedHub:
@@ -25,30 +31,85 @@ class CuratedHub:
def __init__(
self,
hub_name: str,
region: str,
session: Optional[Session] = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
sagemaker_session: Optional[Session] = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
):
"""Instantiates a SageMaker ``CuratedHub``.

Args:
hub_name (str): The name of the Hub to create.
sagemaker_session (sagemaker.session.Session): A SageMaker Session
object, used for SageMaker interactions.
"""
self.hub_name = hub_name
self.region = region
self._sm_session = session
self.region = sagemaker_session.boto_region_name
self._sagemaker_session = sagemaker_session

def describe_model(self, model_name: str, model_version: str = "*") -> Dict[str, Any]:
"""Returns descriptive information about the Hub Model"""
def create(
self,
description: str,
display_name: Optional[str] = None,
search_keywords: Optional[str] = None,
bucket_name: Optional[str] = None,
tags: Optional[str] = None,
) -> Dict[str, str]:
"""Creates a hub with the given description"""

hub_content = self._sm_session.describe_hub_content(
model_name, "Model", self.hub_name, model_version
bucket_name = create_hub_bucket_if_it_does_not_exist(bucket_name, self._sagemaker_session)

return self._sagemaker_session.create_hub(
hub_name=self.hub_name,
hub_description=description,
hub_display_name=display_name,
hub_search_keywords=search_keywords,
hub_bucket_name=bucket_name,
tags=tags,
)

# TODO: Parse HubContent
# TODO: Parse HubContentDocument
def describe(self) -> DescribeHubResponse:
"""Returns descriptive information about the Hub"""

return hub_content
hub_description: DescribeHubResponse = self._sagemaker_session.describe_hub(
hub_name=self.hub_name
)

def describe(self) -> Dict[str, Any]:
"""Returns descriptive information about the Hub"""
return hub_description

hub_info = self._sm_session.describe_hub(hub_name=self.hub_name)
def list_models(self, **kwargs) -> Dict[str, Any]:
"""Lists the models in this Curated Hub

# TODO: Validations?
**kwargs: Passed to invocation of ``Session:list_hub_contents``.
"""
# TODO: Validate kwargs and fast-fail?

hub_content_summaries = self._sagemaker_session.list_hub_contents(
hub_name=self.hub_name, hub_content_type=HubContentType.MODEL, **kwargs
)
# TODO: Handle pagination
return hub_content_summaries

def describe_model(
self, model_name: str, model_version: str = "*"
) -> DescribeHubContentsResponse:
"""Returns descriptive information about the Hub Model"""

hub_content_description: Dict[str, Any] = self._sagemaker_session.describe_hub_content(
hub_name=self.hub_name,
hub_content_name=model_name,
hub_content_version=model_version,
hub_content_type=HubContentType.MODEL,
)

return DescribeHubContentsResponse(hub_content_description)

def delete_model(self, model_name: str, model_version: str = "*") -> None:
"""Deletes a model from this CuratedHub."""
return self._sagemaker_session.delete_hub_content(
hub_content_name=model_name,
hub_content_version=model_version,
hub_content_type=HubContentType.MODEL,
hub_name=self.hub_name,
)

return hub_info
def delete(self) -> None:
"""Deletes this Curated Hub"""
return self._sagemaker_session.delete_hub(self.hub_name)
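
The `# TODO: Handle pagination` note in `list_models` could be addressed with a token loop along these lines. This is a sketch only: `list_all_hub_contents` is a hypothetical helper, the `next_token` keyword is an assumed parameter name for the session method, and the `HubContentSummaries` and `NextToken` response keys are assumptions about the response shape.

```python
from typing import Any, Callable, Dict, List


def list_all_hub_contents(
    list_page: Callable[..., Dict[str, Any]], **kwargs: Any
) -> List[Dict[str, Any]]:
    """Collect summaries from every page of a paginated list call."""
    summaries: List[Dict[str, Any]] = []
    next_token = None
    while True:
        page_kwargs = dict(kwargs)
        if next_token:
            # Assumed keyword name for passing the continuation token.
            page_kwargs["next_token"] = next_token
        page = list_page(**page_kwargs)
        summaries.extend(page.get("HubContentSummaries", []))
        next_token = page.get("NextToken")
        if not next_token:
            return summaries
```

Passing the page-fetching callable in keeps the loop testable without a live session.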
51 changes: 0 additions & 51 deletions src/sagemaker/jumpstart/curated_hub/types.py

This file was deleted.

51 changes: 47 additions & 4 deletions src/sagemaker/jumpstart/curated_hub/utils.py
@@ -14,12 +14,13 @@
from __future__ import absolute_import
import re
from typing import Optional
from sagemaker.jumpstart import constants

from sagemaker.jumpstart.curated_hub.types import HubArnExtractedInfo
from sagemaker.jumpstart.types import HubContentType
from sagemaker.session import Session
from sagemaker.utils import aws_partition
from sagemaker.jumpstart.types import (
HubContentType,
HubArnExtractedInfo,
)
from sagemaker.jumpstart import constants


def get_info_from_hub_resource_arn(
@@ -109,3 +110,45 @@ def generate_hub_arn_for_init_kwargs(
else:
hub_arn = construct_hub_arn_from_name(hub_name=hub_name, region=region, session=session)
return hub_arn


def generate_default_hub_bucket_name(
sagemaker_session: Session = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
) -> str:
"""Return the name of the default bucket to use in relevant Amazon SageMaker Hub interactions.

Returns:
str: The name of the default bucket. If the name was not explicitly specified through
the Session or sagemaker_config, the bucket will take the form:
``sagemaker-hubs-{region}-{AWS account ID}``.
"""

region: str = sagemaker_session.boto_region_name
account_id: str = sagemaker_session.account_id()

# TODO: Validate and fast fail

return f"sagemaker-hubs-{region}-{account_id}"


def create_hub_bucket_if_it_does_not_exist(
bucket_name: Optional[str] = None,
sagemaker_session: Session = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
) -> str:
"""Creates the default SageMaker Hub bucket if it does not exist.

Returns:
str: The name of the default bucket. Takes the form:
``sagemaker-hubs-{region}-{AWS account ID}``.
"""

region: str = sagemaker_session.boto_region_name
if bucket_name is None:
bucket_name: str = generate_default_hub_bucket_name(sagemaker_session)

sagemaker_session._create_s3_bucket_if_it_does_not_exist(
Member

you may want to do an ownership check, in case someone snipes the bucket

Contributor

+1. There is also the question of how we would work around this if the bucket is already taken. Perhaps we say this is very unlikely, but we should at least emit a warning in that case.

Contributor Author

aye aye it seems that the base function does have the ownership check. Wanted to try reusing existing functions but I think we just need to write up our version so we can check ownership.
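
An ownership check of the kind discussed in this thread could be sketched as below. It assumes an S3 client object exposing boto3's `head_bucket` with the `ExpectedBucketOwner` parameter; `bucket_is_owned_by` is a hypothetical helper, and the error-code handling is an assumption about how the client reports failures. A 403 can also mean missing permissions, so the warning is worded cautiously.

```python
import logging
from typing import Any

logger = logging.getLogger(__name__)


def bucket_is_owned_by(s3_client: Any, bucket_name: str, account_id: str) -> bool:
    """Return True if the bucket exists and is accessible as ``account_id``."""
    try:
        # ExpectedBucketOwner asks S3 to reject the call if another account owns the bucket.
        s3_client.head_bucket(Bucket=bucket_name, ExpectedBucketOwner=account_id)
        return True
    except Exception as error:  # botocore raises ClientError; kept generic to avoid the import
        code = getattr(error, "response", {}).get("Error", {}).get("Code")
        if code in ("403", "AccessDenied"):
            logger.warning(
                "Bucket %s is not accessible as account %s; it may belong to another account.",
                bucket_name,
                account_id,
            )
            return False
        if code in ("404", "NoSuchBucket"):
            return False
        raise
```

A caller could run this before creating or reusing the default hub bucket and fail fast, or pick another name, when it returns `False` for an existing bucket.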

bucket_name=bucket_name,
region=region,
)

return bucket_name
4 changes: 2 additions & 2 deletions src/sagemaker/jumpstart/session_utils.py
@@ -15,11 +15,11 @@
from __future__ import absolute_import

from typing import Optional, Tuple
from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION

from sagemaker.jumpstart.utils import get_jumpstart_model_id_version_from_resource_arn
from sagemaker.session import Session
from sagemaker.utils import aws_partition
from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION
from sagemaker.jumpstart.utils import get_jumpstart_model_id_version_from_resource_arn


def get_model_id_version_from_endpoint(