-
Notifications
You must be signed in to change notification settings - Fork 1.2k
feat: jumpstart default payloads #4149
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
cba9034
6082e59
bf13065
9750dc5
5dc4aac
ca20d87
b2a6374
214e458
0a61992
aa3163e
9bfb962
89980b9
47be5b7
fc900e7
f7bb0aa
777b24a
c869f30
2aecf36
e5a7058
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -12,6 +12,7 @@ | |
# language governing permissions and limitations under the License. | ||
"""This module contains accessors related to SageMaker JumpStart.""" | ||
from __future__ import absolute_import | ||
import functools | ||
from typing import Any, Dict, List, Optional | ||
import boto3 | ||
|
||
|
@@ -37,6 +38,88 @@ def get_sagemaker_version() -> str: | |
return SageMakerSettings._parsed_sagemaker_version | ||
|
||
|
||
class JumpStartS3PayloadAccessor(object): | ||
"""Static class for storing and retrieving S3 payload artifacts.""" | ||
|
||
MAX_CACHE_SIZE_BYTES = int(100 * 1e6) | ||
MAX_PAYLOAD_SIZE_BYTES = int(6 * 1e6) | ||
|
||
CACHE_SIZE = MAX_CACHE_SIZE_BYTES // MAX_PAYLOAD_SIZE_BYTES | ||
|
||
@staticmethod | ||
def clear_cache() -> None: | ||
"""Clears LRU caches associated with S3 client and retrieved objects.""" | ||
|
||
JumpStartS3PayloadAccessor._get_default_s3_client.cache_clear() | ||
JumpStartS3PayloadAccessor.get_object_cached.cache_clear() | ||
|
||
@staticmethod | ||
@functools.lru_cache() | ||
def _get_default_s3_client(region: str = JUMPSTART_DEFAULT_REGION_NAME) -> boto3.client: | ||
"""Returns default S3 client associated with the region. | ||
|
||
Result is cached so multiple clients in memory are not created. | ||
""" | ||
return boto3.client("s3", region_name=region) | ||
|
||
@staticmethod | ||
@functools.lru_cache(maxsize=CACHE_SIZE) | ||
def get_object_cached( | ||
bucket: str, | ||
key: str, | ||
region: str = JUMPSTART_DEFAULT_REGION_NAME, | ||
s3_client: Optional[boto3.client] = None, | ||
) -> bytes: | ||
"""Returns S3 object located at the bucket and key. | ||
|
||
Requests are cached so that the same S3 request is never made more | ||
than once, unless a different region or client is used. | ||
""" | ||
return JumpStartS3PayloadAccessor.get_object( | ||
bucket=bucket, key=key, region=region, s3_client=s3_client | ||
) | ||
|
||
@staticmethod | ||
def _get_object_size_bytes( | ||
bucket: str, | ||
key: str, | ||
region: str = JUMPSTART_DEFAULT_REGION_NAME, | ||
s3_client: Optional[boto3.client] = None, | ||
) -> bytes: | ||
"""Returns size in bytes of S3 object using S3.HeadObject operation.""" | ||
if s3_client is None: | ||
s3_client = JumpStartS3PayloadAccessor._get_default_s3_client(region) | ||
|
||
return s3_client.head_object(Bucket=bucket, Key=key)["ContentLength"] | ||
|
||
@staticmethod | ||
def get_object( | ||
bucket: str, | ||
key: str, | ||
region: str = JUMPSTART_DEFAULT_REGION_NAME, | ||
s3_client: Optional[boto3.client] = None, | ||
) -> bytes: | ||
"""Returns S3 object located at the bucket and key. | ||
|
||
Raises: | ||
ValueError: The object size is too large. | ||
""" | ||
if s3_client is None: | ||
s3_client = JumpStartS3PayloadAccessor._get_default_s3_client(region) | ||
|
||
object_size_bytes = JumpStartS3PayloadAccessor._get_object_size_bytes( | ||
bucket=bucket, key=key, region=region, s3_client=s3_client | ||
) | ||
if object_size_bytes > JumpStartS3PayloadAccessor.MAX_PAYLOAD_SIZE_BYTES: | ||
raise ValueError( | ||
f"s3://{bucket}/{key} has size of {object_size_bytes} bytes, " | ||
"which exceeds maximum allowed size of " | ||
f"{JumpStartS3PayloadAccessor.MAX_PAYLOAD_SIZE_BYTES} bytes." | ||
) | ||
|
||
return s3_client.get_object(Bucket=bucket, Key=key)["Body"].read() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this line will buffer the whole object in memory? Is that acceptable and shouldn't you build in safeguards? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What kind of safeguards are you referring to? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. for posterity as you addressed above: inference file size mainly. |
||
|
||
|
||
class JumpStartModelsAccessor(object): | ||
"""Static class for storing the JumpStart models cache.""" | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"). You | ||
# may not use this file except in compliance with the License. A copy of | ||
# the License is located at | ||
# | ||
# http://aws.amazon.com/apache2.0/ | ||
# | ||
# or in the "license" file accompanying this file. This file is | ||
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF | ||
# ANY KIND, either express or implied. See the License for the specific | ||
# language governing permissions and limitations under the License. | ||
"""This module contains functions to obtain JumpStart model payloads.""" | ||
from __future__ import absolute_import | ||
from copy import deepcopy | ||
from typing import Dict, Optional | ||
from sagemaker.jumpstart.constants import ( | ||
DEFAULT_JUMPSTART_SAGEMAKER_SESSION, | ||
JUMPSTART_DEFAULT_REGION_NAME, | ||
) | ||
from sagemaker.jumpstart.enums import ( | ||
JumpStartScriptScope, | ||
) | ||
from sagemaker.jumpstart.types import JumpStartSerializablePayload | ||
from sagemaker.jumpstart.utils import ( | ||
verify_model_region_and_return_specs, | ||
) | ||
from sagemaker.session import Session | ||
|
||
|
||
def _retrieve_example_payloads( | ||
model_id: str, | ||
model_version: str, | ||
region: Optional[str], | ||
tolerate_vulnerable_model: bool = False, | ||
tolerate_deprecated_model: bool = False, | ||
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, | ||
) -> Optional[Dict[str, JumpStartSerializablePayload]]: | ||
"""Returns example payloads. | ||
|
||
Args: | ||
model_id (str): JumpStart model ID of the JumpStart model for which to | ||
get example payloads. | ||
model_version (str): Version of the JumpStart model for which to retrieve the | ||
example payloads. | ||
region (Optional[str]): Region for which to retrieve the | ||
example payloads. | ||
tolerate_vulnerable_model (bool): True if vulnerable versions of model | ||
specifications should be tolerated (exception not raised). If False, raises an | ||
exception if the script used by this version of the model has dependencies with known | ||
security vulnerabilities. (Default: False). | ||
tolerate_deprecated_model (bool): True if deprecated versions of model | ||
evakravi marked this conversation as resolved.
Show resolved
Hide resolved
|
||
specifications should be tolerated (exception not raised). If False, raises | ||
an exception if the version of the model is deprecated. (Default: False). | ||
sagemaker_session (sagemaker.session.Session): A SageMaker Session | ||
object, used for SageMaker interactions. If not | ||
specified, one is created using the default AWS configuration | ||
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). | ||
Returns: | ||
Optional[Dict[str, JumpStartSerializablePayload]]: dictionary mapping payload aliases | ||
to the serializable payload object. | ||
""" | ||
|
||
if region is None: | ||
region = JUMPSTART_DEFAULT_REGION_NAME | ||
|
||
model_specs = verify_model_region_and_return_specs( | ||
model_id=model_id, | ||
version=model_version, | ||
scope=JumpStartScriptScope.INFERENCE, | ||
region=region, | ||
tolerate_vulnerable_model=tolerate_vulnerable_model, | ||
tolerate_deprecated_model=tolerate_deprecated_model, | ||
sagemaker_session=sagemaker_session, | ||
) | ||
|
||
default_payloads = model_specs.default_payloads | ||
|
||
if default_payloads: | ||
for payload in default_payloads.values(): | ||
payload.accept = getattr( | ||
payload, "accept", model_specs.predictor_specs.default_accept_type | ||
) | ||
|
||
return deepcopy(default_payloads) if default_payloads else None |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
hum, not sure about this, you are effectively caching objects in memory aren't you?
Could you:
(a) determine the max memory you are willing to use for such cache?
(b) add a head object with a size limits for such objects
(c) derive the max number of items in the
@lru_cache(max_items)
from round((a)/(b)) ?Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The maximum memory is very system dependent and the payloads will come in all different sizes. How about we expose a function that clears the cache
JumpStartS3Accessor.clear_cache()
? This can callJumpStartS3Accessor.get_object_cached.cache_clear()
under the hood. See: https://stackoverflow.com/questions/37653784/how-do-i-use-cache-clear-on-python-functools-lru-cache