-
Notifications
You must be signed in to change notification settings - Fork 1.2k
feature: JumpStart CuratedHub class creation and function definitions #4448
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
4922511
f4da2ad
d5937b3
0435ae9
a4c67e8
4799ccb
086bf92
b600dd1
1506147
a8d6664
edb887e
5f24036
b0ce624
0937c74
7c87c52
dddd0a6
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -21,34 +21,38 @@ | |
import botocore | ||
from packaging.version import Version | ||
from packaging.specifiers import SpecifierSet, InvalidSpecifier | ||
from sagemaker.session import Session | ||
from sagemaker.utilities.cache import LRUCache | ||
from sagemaker.jumpstart.constants import ( | ||
ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE, | ||
ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE, | ||
JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY, | ||
JUMPSTART_DEFAULT_REGION_NAME, | ||
JUMPSTART_LOGGER, | ||
MODEL_ID_LIST_WEB_URL, | ||
DEFAULT_JUMPSTART_SAGEMAKER_SESSION, | ||
) | ||
from sagemaker.jumpstart.curated_hub.curated_hub import CuratedHub | ||
from sagemaker.jumpstart.curated_hub.utils import get_info_from_hub_resource_arn | ||
from sagemaker.jumpstart.exceptions import get_wildcard_model_version_msg | ||
from sagemaker.jumpstart.parameters import ( | ||
JUMPSTART_DEFAULT_MAX_S3_CACHE_ITEMS, | ||
JUMPSTART_DEFAULT_MAX_SEMANTIC_VERSION_CACHE_ITEMS, | ||
JUMPSTART_DEFAULT_S3_CACHE_EXPIRATION_HORIZON, | ||
JUMPSTART_DEFAULT_SEMANTIC_VERSION_CACHE_EXPIRATION_HORIZON, | ||
) | ||
from sagemaker.jumpstart import utils | ||
from sagemaker.jumpstart.types import ( | ||
JumpStartCachedContentKey, | ||
JumpStartCachedContentValue, | ||
JumpStartModelHeader, | ||
JumpStartModelSpecs, | ||
JumpStartS3FileType, | ||
JumpStartVersionedModelId, | ||
DescribeHubResponse, | ||
DescribeHubContentsResponse, | ||
HubType, | ||
HubContentType, | ||
) | ||
from sagemaker.jumpstart import utils | ||
from sagemaker.utilities.cache import LRUCache | ||
from sagemaker.jumpstart.curated_hub import utils as hub_utils | ||
|
||
|
||
class JumpStartModelsCache: | ||
|
@@ -74,6 +78,7 @@ def __init__( | |
s3_bucket_name: Optional[str] = None, | ||
s3_client_config: Optional[botocore.config.Config] = None, | ||
s3_client: Optional[boto3.client] = None, | ||
sagemaker_session: Optional[Session] = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, | ||
) -> None: # fmt: on | ||
"""Initialize a ``JumpStartModelsCache`` instance. | ||
|
||
|
@@ -95,6 +100,8 @@ def __init__( | |
s3_client_config (Optional[botocore.config.Config]): s3 client config to use for cache. | ||
Default: None (no config). | ||
s3_client (Optional[boto3.client]): s3 client to use. Default: None. | ||
sagemaker_session (Optional[sagemaker.session.Session]): A SageMaker Session object, | ||
used for SageMaker interactions. Default: Session in region associated with boto3 session. | ||
""" | ||
|
||
self._region = region | ||
|
@@ -121,6 +128,7 @@ def __init__( | |
if s3_client_config | ||
else boto3.client("s3", region_name=self._region) | ||
) | ||
self._sagemaker_session = sagemaker_session | ||
|
||
def set_region(self, region: str) -> None: | ||
"""Set region for cache. Clears cache after new region is set.""" | ||
|
@@ -340,32 +348,34 @@ def _retrieval_function( | |
formatted_content=model_specs | ||
) | ||
if data_type == HubContentType.MODEL: | ||
info = get_info_from_hub_resource_arn( | ||
hub_name, _, model_name, model_version = hub_utils.get_info_from_hub_resource_arn( | ||
id_info | ||
) | ||
hub = CuratedHub(hub_name=info.hub_name, region=info.region) | ||
hub_content = hub.describe_model( | ||
model_name=info.hub_content_name, model_version=info.hub_content_version | ||
hub_model_description: Dict[str, Any] = self._sagemaker_session.describe_hub_content( | ||
hub_name=hub_name, | ||
hub_content_name=model_name, | ||
hub_content_version=model_version, | ||
hub_content_type=data_type | ||
) | ||
|
||
model_specs = JumpStartModelSpecs(DescribeHubContentsResponse(hub_model_description), is_hub_content=True) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. We should do one or the other, and stick with that throughout the session module. Personally I'd prefer returning the class rather than a dictionary.
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I did make all data classes for HubAPI responses. But it seems that having those as return types in
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Can we keep the constructor for
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. @whittech1 To clarify, are you pointing out the typecast in
and instead it should be:
? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I am saying that the constructor for
|
||
|
||
utils.emit_logs_based_on_model_specs( | ||
hub_content.content_document, | ||
model_specs, | ||
self.get_region(), | ||
self._s3_client | ||
) | ||
model_specs = JumpStartModelSpecs(hub_content.content_document, is_hub_content=True) | ||
return JumpStartCachedContentValue( | ||
formatted_content=model_specs | ||
) | ||
if data_type == HubContentType.HUB: | ||
info = get_info_from_hub_resource_arn( | ||
id_info | ||
) | ||
hub = CuratedHub(hub_name=info.hub_name, region=info.region) | ||
hub_info = hub.describe() | ||
return JumpStartCachedContentValue(formatted_content=hub_info) | ||
if data_type == HubType.HUB: | ||
hub_name, _, _, _ = hub_utils.get_info_from_hub_resource_arn(id_info) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Can you make the return type a data class in the next PR?
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Clarifying question: Did you mean something like this?
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yes, exactly! |
||
response: Dict[str, Any] = self._sagemaker_session.describe_hub(hub_name=hub_name) | ||
hub_description = DescribeHubResponse(response) | ||
return JumpStartCachedContentValue(formatted_content=DescribeHubResponse(hub_description)) | ||
raise ValueError( | ||
f"Bad value for key '{key}': must be in", | ||
f"{[JumpStartS3FileType.MANIFEST, JumpStartS3FileType.SPECS, HubContentType.HUB, HubContentType.MODEL]}" | ||
f"Bad value for key '{key}': must be in ", | ||
f"{[JumpStartS3FileType.MANIFEST, JumpStartS3FileType.SPECS, HubType.HUB, HubContentType.MODEL]}" | ||
) | ||
|
||
def get_manifest(self) -> List[JumpStartModelHeader]: | ||
|
@@ -490,7 +500,7 @@ def get_hub(self, hub_arn: str) -> Dict[str, Any]: | |
hub_arn (str): Arn for the Hub to get info for | ||
""" | ||
|
||
details, _ = self._content_cache.get(JumpStartCachedContentKey(HubContentType.HUB, hub_arn)) | ||
details, _ = self._content_cache.get(JumpStartCachedContentKey(HubType.HUB, hub_arn)) | ||
return details.formatted_content | ||
|
||
def clear(self) -> None: | ||
|
This file was deleted.
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -14,12 +14,13 @@ | |
from __future__ import absolute_import | ||
import re | ||
from typing import Optional | ||
from sagemaker.jumpstart import constants | ||
|
||
from sagemaker.jumpstart.curated_hub.types import HubArnExtractedInfo | ||
from sagemaker.jumpstart.types import HubContentType | ||
from sagemaker.session import Session | ||
from sagemaker.utils import aws_partition | ||
from sagemaker.jumpstart.types import ( | ||
HubContentType, | ||
HubArnExtractedInfo, | ||
) | ||
from sagemaker.jumpstart import constants | ||
|
||
|
||
def get_info_from_hub_resource_arn( | ||
|
@@ -109,3 +110,45 @@ def generate_hub_arn_for_init_kwargs( | |
else: | ||
hub_arn = construct_hub_arn_from_name(hub_name=hub_name, region=region, session=session) | ||
return hub_arn | ||
|
||
|
||
def generate_default_hub_bucket_name(
    sagemaker_session: Session = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
) -> str:
    """Build the default S3 bucket name used for Amazon SageMaker Hub interactions.

    Args:
        sagemaker_session (Session): Session that supplies the region
            (``boto_region_name``) and the AWS account ID (``account_id()``).

    Returns:
        str: A bucket name of the form
            ``sagemaker-hubs-{region}-{AWS account ID}``.
    """

    hub_region: str = sagemaker_session.boto_region_name
    aws_account: str = sagemaker_session.account_id()

    # TODO: Validate and fast fail

    return f"sagemaker-hubs-{hub_region}-{aws_account}"
|
||
|
||
def create_hub_bucket_if_it_does_not_exist(
    bucket_name: Optional[str] = None,
    sagemaker_session: Session = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
) -> str:
    """Create the SageMaker Hub bucket if it does not exist.

    Args:
        bucket_name (Optional[str]): Name of the bucket to create. When ``None``,
            the name produced by ``generate_default_hub_bucket_name`` is used.
        sagemaker_session (Session): Session whose region and S3 bucket helper are
            used. ``None`` falls back to
            ``constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION``.

    Returns:
        str: The bucket name that was passed to the session's create call: the
            caller-supplied ``bucket_name``, or the generated default of the form
            ``sagemaker-hubs-{region}-{AWS account ID}``.
    """

    # Callers may forward an Optional session; fall back rather than fail on
    # attribute access against ``None``.
    if sagemaker_session is None:
        sagemaker_session = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION

    region: str = sagemaker_session.boto_region_name
    if bucket_name is None:
        # Plain rebinding: the parameter is already annotated in the signature.
        bucket_name = generate_default_hub_bucket_name(sagemaker_session)

    # NOTE(review): reviewers asked for a bucket-ownership check (in case the
    # bucket name is already taken by another account) and at least a warning in
    # that case. Whether the session helper below performs such a check is not
    # visible here -- TODO confirm, or replace with a Hub-specific implementation.
    sagemaker_session._create_s3_bucket_if_it_does_not_exist(
        bucket_name=bucket_name,
        region=region,
    )

    return bucket_name
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Don't we need an existence check on the session variable, since it could be `None`?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I had `self._sagemaker_session` default to `DEFAULT_JUMPSTART_SAGEMAKER_SESSION`, unless it changed... I will take a look in the next PR.