-
Notifications
You must be signed in to change notification settings - Fork 1.2k
feat: jumpstart default payloads #4149
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 17 commits
cba9034
6082e59
bf13065
9750dc5
5dc4aac
ca20d87
b2a6374
214e458
0a61992
aa3163e
9bfb962
89980b9
47be5b7
fc900e7
f7bb0aa
777b24a
c869f30
2aecf36
e5a7058
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -12,6 +12,7 @@ | |
# language governing permissions and limitations under the License. | ||
"""This module contains accessors related to SageMaker JumpStart.""" | ||
from __future__ import absolute_import | ||
import functools | ||
from typing import Any, Dict, List, Optional | ||
import boto3 | ||
|
||
|
@@ -37,6 +38,88 @@ def get_sagemaker_version() -> str: | |
return SageMakerSettings._parsed_sagemaker_version | ||
|
||
|
||
class JumpStartS3PayloadAccessor(object): | ||
"""Static class for storing and retrieving S3 payload artifacts.""" | ||
|
||
MAX_CACHE_SIZE_BYTES = int(100 * 1e6) | ||
MAX_PAYLOAD_SIZE_BYTES = int(6 * 1e6) | ||
|
||
CACHE_SIZE = MAX_CACHE_SIZE_BYTES // MAX_PAYLOAD_SIZE_BYTES | ||
|
||
@staticmethod | ||
def clear_cache() -> None: | ||
"""Clears LRU caches associated with S3 client and retrieved objects.""" | ||
|
||
JumpStartS3PayloadAccessor._get_default_s3_client.cache_clear() | ||
JumpStartS3PayloadAccessor.get_object_cached.cache_clear() | ||
|
||
@staticmethod | ||
@functools.lru_cache() | ||
def _get_default_s3_client(region: str = JUMPSTART_DEFAULT_REGION_NAME) -> boto3.client: | ||
"""Returns default S3 client associated with the region. | ||
|
||
Result is cached so multiple clients in memory are not created. | ||
""" | ||
return boto3.client("s3", region_name=region) | ||
|
||
@staticmethod | ||
@functools.lru_cache(maxsize=CACHE_SIZE) | ||
def get_object_cached( | ||
bucket: str, | ||
key: str, | ||
region: str = JUMPSTART_DEFAULT_REGION_NAME, | ||
s3_client: Optional[boto3.client] = None, | ||
) -> bytes: | ||
"""Returns S3 object located at the bucket and key. | ||
|
||
Requests are cached so that the same S3 request is never made more | ||
than once, unless a different region or client is used. | ||
""" | ||
return JumpStartS3PayloadAccessor.get_object( | ||
bucket=bucket, key=key, region=region, s3_client=s3_client | ||
) | ||
|
||
@staticmethod | ||
def _get_object_size_bytes( | ||
bucket: str, | ||
key: str, | ||
region: str = JUMPSTART_DEFAULT_REGION_NAME, | ||
s3_client: Optional[boto3.client] = None, | ||
) -> bytes: | ||
"""Returns size in bytes of S3 object using S3.HeadObject operation.""" | ||
if s3_client is None: | ||
s3_client = JumpStartS3PayloadAccessor._get_default_s3_client(region) | ||
|
||
return s3_client.head_object(Bucket=bucket, Key=key)["ContentLength"] | ||
|
||
@staticmethod | ||
def get_object( | ||
bucket: str, | ||
key: str, | ||
region: str = JUMPSTART_DEFAULT_REGION_NAME, | ||
s3_client: Optional[boto3.client] = None, | ||
) -> bytes: | ||
"""Returns S3 object located at the bucket and key. | ||
|
||
Raises: | ||
ValueError: The object size is too large. | ||
""" | ||
if s3_client is None: | ||
s3_client = JumpStartS3PayloadAccessor._get_default_s3_client(region) | ||
|
||
object_size_bytes = JumpStartS3PayloadAccessor._get_object_size_bytes( | ||
bucket=bucket, key=key, region=region, s3_client=s3_client | ||
) | ||
if object_size_bytes > JumpStartS3PayloadAccessor.MAX_PAYLOAD_SIZE_BYTES: | ||
raise ValueError( | ||
f"s3://{bucket}/{key} has size of {object_size_bytes} bytes, " | ||
"which exceeds maximum allowed size of " | ||
f"{JumpStartS3PayloadAccessor.MAX_PAYLOAD_SIZE_BYTES} bytes." | ||
) | ||
|
||
return s3_client.get_object(Bucket=bucket, Key=key)["Body"].read() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this line will buffer the whole object in memory? Is that acceptable and shouldn't you build in safeguards? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What kind of safeguards are you referring to? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. for posterity as you addressed above: inference file size mainly. |
||
|
||
|
||
class JumpStartModelsAccessor(object): | ||
"""Static class for storing the JumpStart models cache.""" | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"). You | ||
# may not use this file except in compliance with the License. A copy of | ||
# the License is located at | ||
# | ||
# http://aws.amazon.com/apache2.0/ | ||
# | ||
# or in the "license" file accompanying this file. This file is | ||
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF | ||
# ANY KIND, either express or implied. See the License for the specific | ||
# language governing permissions and limitations under the License. | ||
"""This module contains functions to obtain JumpStart model payloads.""" | ||
from __future__ import absolute_import | ||
from copy import deepcopy | ||
from typing import Dict, Optional | ||
from sagemaker.jumpstart.constants import ( | ||
DEFAULT_JUMPSTART_SAGEMAKER_SESSION, | ||
JUMPSTART_DEFAULT_REGION_NAME, | ||
) | ||
from sagemaker.jumpstart.enums import ( | ||
JumpStartScriptScope, | ||
) | ||
from sagemaker.jumpstart.types import JumpStartSerializablePayload | ||
from sagemaker.jumpstart.utils import ( | ||
verify_model_region_and_return_specs, | ||
) | ||
from sagemaker.session import Session | ||
|
||
|
||
def _retrieve_default_payloads( | ||
model_id: str, | ||
model_version: str, | ||
region: Optional[str], | ||
tolerate_vulnerable_model: bool = False, | ||
tolerate_deprecated_model: bool = False, | ||
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, | ||
) -> Optional[Dict[str, JumpStartSerializablePayload]]: | ||
"""Returns default payloads. | ||
|
||
Args: | ||
model_id (str): JumpStart model ID of the JumpStart model for which to | ||
get default payloads. | ||
model_version (str): Version of the JumpStart model for which to retrieve the | ||
default resource name. | ||
region (Optional[str]): Region for which to retrieve the | ||
default resource name. | ||
tolerate_vulnerable_model (bool): True if vulnerable versions of model | ||
specifications should be tolerated (exception not raised). If False, raises an | ||
exception if the script used by this version of the model has dependencies with known | ||
security vulnerabilities. (Default: False). | ||
tolerate_deprecated_model (bool): True if deprecated versions of model | ||
evakravi marked this conversation as resolved.
Show resolved
Hide resolved
|
||
specifications should be tolerated (exception not raised). If False, raises | ||
an exception if the version of the model is deprecated. (Default: False). | ||
sagemaker_session (sagemaker.session.Session): A SageMaker Session | ||
object, used for SageMaker interactions. If not | ||
specified, one is created using the default AWS configuration | ||
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). | ||
Returns: | ||
str: the default payload. | ||
""" | ||
|
||
if region is None: | ||
region = JUMPSTART_DEFAULT_REGION_NAME | ||
|
||
model_specs = verify_model_region_and_return_specs( | ||
model_id=model_id, | ||
version=model_version, | ||
scope=JumpStartScriptScope.INFERENCE, | ||
region=region, | ||
tolerate_vulnerable_model=tolerate_vulnerable_model, | ||
tolerate_deprecated_model=tolerate_deprecated_model, | ||
sagemaker_session=sagemaker_session, | ||
) | ||
|
||
default_payloads = model_specs.default_payloads | ||
|
||
if default_payloads: | ||
for payload in default_payloads.values(): | ||
payload.accept = getattr( | ||
payload, "accept", model_specs.predictor_specs.default_accept_type | ||
) | ||
|
||
return deepcopy(default_payloads) if default_payloads else None |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -16,6 +16,7 @@ | |
import re | ||
|
||
from typing import Dict, List, Optional, Union | ||
from sagemaker import payloads | ||
from sagemaker.async_inference.async_inference_config import AsyncInferenceConfig | ||
from sagemaker.base_deserializers import BaseDeserializer | ||
from sagemaker.base_serializers import BaseSerializer | ||
|
@@ -28,6 +29,7 @@ | |
get_deploy_kwargs, | ||
get_init_kwargs, | ||
) | ||
from sagemaker.jumpstart.types import JumpStartSerializablePayload | ||
from sagemaker.jumpstart.utils import is_valid_model_id | ||
from sagemaker.utils import stringify_object | ||
from sagemaker.model import MODEL_PACKAGE_ARN_PATTERN, Model | ||
|
@@ -312,6 +314,27 @@ def _is_valid_model_id_hook(): | |
|
||
super(JumpStartModel, self).__init__(**model_init_kwargs.to_kwargs_dict()) | ||
|
||
def retrieve_default_payload(self) -> JumpStartSerializablePayload: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. sorry I missed this: this instance method needs be renamed Please also add a |
||
"""Returns the default payload associated with the model. | ||
|
||
Payload can be directly used with the `sagemaker.predictor.Predictor.predict(...)` function. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this is customer facing: please add a There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. if Related question: do you need a way to propagate the |
||
|
||
Raises: | ||
NotImplementedError: If the scope is not supported. | ||
ValueError: If the combination of arguments specified is not supported. | ||
VulnerableJumpStartModelError: If any of the dependencies required by the script have | ||
known security vulnerabilities. | ||
DeprecatedJumpStartModelError: If the version of the model is deprecated. | ||
""" | ||
return payloads.retrieve_example( | ||
model_id=self.model_id, | ||
model_version=self.model_version, | ||
region=self.region, | ||
tolerate_deprecated_model=self.tolerate_deprecated_model, | ||
tolerate_vulnerable_model=self.tolerate_vulnerable_model, | ||
sagemaker_session=self.sagemaker_session, | ||
) | ||
|
||
def _create_sagemaker_model( | ||
self, | ||
instance_type=None, | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
hum, not sure about this, you are effectively caching objects in memory aren't you?
Could you:
(a) determine the max memory you are willing to use for such cache?
(b) add a head object with a size limits for such objects
(c) derive the max number of items in the
@lru_cache(max_items)
from round((a)/(b)) ?Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The maximum memory is very system dependent and the payloads will come in all different sizes. How about we expose a function that clears the cache
JumpStartS3Accessor.clear_cache()
? This can callJumpStartS3Accessor.get_object_cached.cache_clear()
under the hood. See: https://stackoverflow.com/questions/37653784/how-do-i-use-cache-clear-on-python-functools-lru-cache