Skip to content

feat: jumpstart default payloads #4149

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 19 commits into from
Oct 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 39 additions & 8 deletions src/sagemaker/base_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from __future__ import print_function, absolute_import

import abc
from typing import Any, Tuple
from typing import Any, Optional, Tuple, Union

from sagemaker.deprecations import (
deprecated_class,
Expand All @@ -32,6 +32,9 @@
StreamDeserializer,
StringDeserializer,
)
from sagemaker.jumpstart.payload_utils import PayloadSerializer
from sagemaker.jumpstart.types import JumpStartSerializablePayload
from sagemaker.jumpstart.utils import get_jumpstart_content_bucket
from sagemaker.model_monitor import (
DataCaptureConfig,
DefaultModelMonitor,
Expand Down Expand Up @@ -201,20 +204,44 @@ def _create_request_args(
custom_attributes=None,
):
"""Placeholder docstring"""

jumpstart_serialized_data: Optional[Union[str, bytes]] = None
jumpstart_accept: Optional[str] = None
jumpstart_content_type: Optional[str] = None

if isinstance(data, JumpStartSerializablePayload):
s3_client = self.sagemaker_session.s3_client
region = self.sagemaker_session._region_name
bucket = get_jumpstart_content_bucket(region)

jumpstart_serialized_data = PayloadSerializer(
bucket=bucket, region=region, s3_client=s3_client
).serialize(data)
jumpstart_content_type = data.content_type
jumpstart_accept = data.accept

args = dict(initial_args) if initial_args else {}

if "EndpointName" not in args:
args["EndpointName"] = self.endpoint_name

if "ContentType" not in args:
args["ContentType"] = (
self.content_type
if isinstance(self.content_type, str)
else ", ".join(self.content_type)
)
if isinstance(data, JumpStartSerializablePayload) and jumpstart_content_type:
args["ContentType"] = jumpstart_content_type
else:
args["ContentType"] = (
self.content_type
if isinstance(self.content_type, str)
else ", ".join(self.content_type)
)

if "Accept" not in args:
args["Accept"] = self.accept if isinstance(self.accept, str) else ", ".join(self.accept)
if isinstance(data, JumpStartSerializablePayload) and jumpstart_accept:
args["Accept"] = jumpstart_accept
else:
args["Accept"] = (
self.accept if isinstance(self.accept, str) else ", ".join(self.accept)
)

if target_model:
args["TargetModel"] = target_model
Expand All @@ -228,7 +255,11 @@ def _create_request_args(
if custom_attributes:
args["CustomAttributes"] = custom_attributes

data = self.serializer.serialize(data)
data = (
jumpstart_serialized_data
if isinstance(data, JumpStartSerializablePayload) and jumpstart_serialized_data
else self.serializer.serialize(data)
)

args["Body"] = data
return args
Expand Down
83 changes: 83 additions & 0 deletions src/sagemaker/jumpstart/accessors.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# language governing permissions and limitations under the License.
"""This module contains accessors related to SageMaker JumpStart."""
from __future__ import absolute_import
import functools
from typing import Any, Dict, List, Optional
import boto3

Expand All @@ -37,6 +38,88 @@ def get_sagemaker_version() -> str:
return SageMakerSettings._parsed_sagemaker_version


class JumpStartS3PayloadAccessor(object):
"""Static class for storing and retrieving S3 payload artifacts."""

MAX_CACHE_SIZE_BYTES = int(100 * 1e6)
MAX_PAYLOAD_SIZE_BYTES = int(6 * 1e6)

CACHE_SIZE = MAX_CACHE_SIZE_BYTES // MAX_PAYLOAD_SIZE_BYTES

@staticmethod
def clear_cache() -> None:
"""Clears LRU caches associated with S3 client and retrieved objects."""

JumpStartS3PayloadAccessor._get_default_s3_client.cache_clear()
JumpStartS3PayloadAccessor.get_object_cached.cache_clear()

@staticmethod
@functools.lru_cache()
def _get_default_s3_client(region: str = JUMPSTART_DEFAULT_REGION_NAME) -> boto3.client:
"""Returns default S3 client associated with the region.

Result is cached so multiple clients in memory are not created.
"""
return boto3.client("s3", region_name=region)

@staticmethod
@functools.lru_cache(maxsize=CACHE_SIZE)
def get_object_cached(
bucket: str,
key: str,
region: str = JUMPSTART_DEFAULT_REGION_NAME,
s3_client: Optional[boto3.client] = None,
) -> bytes:
"""Returns S3 object located at the bucket and key.

Requests are cached so that the same S3 request is never made more
than once, unless a different region or client is used.
"""
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hum, not sure about this, you are effectively caching objects in memory aren't you?

Could you:
(a) determine the max memory you are willing to use for such cache?
(b) add a head object with a size limits for such objects
(c) derive the max number of items in the @lru_cache(max_items) from round((a)/(b)) ?

Copy link
Member Author

@evakravi evakravi Oct 4, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The maximum memory is very system dependent and the payloads will come in all different sizes. How about we expose a function that clears the cache JumpStartS3Accessor.clear_cache()? This can call JumpStartS3Accessor.get_object_cached.cache_clear() under the hood. See: https://stackoverflow.com/questions/37653784/how-do-i-use-cache-clear-on-python-functools-lru-cache

return JumpStartS3PayloadAccessor.get_object(
bucket=bucket, key=key, region=region, s3_client=s3_client
)

@staticmethod
def _get_object_size_bytes(
bucket: str,
key: str,
region: str = JUMPSTART_DEFAULT_REGION_NAME,
s3_client: Optional[boto3.client] = None,
) -> bytes:
"""Returns size in bytes of S3 object using S3.HeadObject operation."""
if s3_client is None:
s3_client = JumpStartS3PayloadAccessor._get_default_s3_client(region)

return s3_client.head_object(Bucket=bucket, Key=key)["ContentLength"]

@staticmethod
def get_object(
bucket: str,
key: str,
region: str = JUMPSTART_DEFAULT_REGION_NAME,
s3_client: Optional[boto3.client] = None,
) -> bytes:
"""Returns S3 object located at the bucket and key.

Raises:
ValueError: The object size is too large.
"""
if s3_client is None:
s3_client = JumpStartS3PayloadAccessor._get_default_s3_client(region)

object_size_bytes = JumpStartS3PayloadAccessor._get_object_size_bytes(
bucket=bucket, key=key, region=region, s3_client=s3_client
)
if object_size_bytes > JumpStartS3PayloadAccessor.MAX_PAYLOAD_SIZE_BYTES:
raise ValueError(
f"s3://{bucket}/{key} has size of {object_size_bytes} bytes, "
"which exceeds maximum allowed size of "
f"{JumpStartS3PayloadAccessor.MAX_PAYLOAD_SIZE_BYTES} bytes."
)

return s3_client.get_object(Bucket=bucket, Key=key)["Body"].read()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this line will buffer the whole object in memory? Is that acceptable and shouldn't you build in safeguards?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What kind of safeguards are you referring to?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

for posterity as you addressed above: inference file size mainly.



class JumpStartModelsAccessor(object):
"""Static class for storing the JumpStart models cache."""

Expand Down
3 changes: 3 additions & 0 deletions src/sagemaker/jumpstart/artifacts/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,3 +61,6 @@
_retrieve_model_package_arn,
_retrieve_model_package_model_artifact_s3_uri,
)
from sagemaker.jumpstart.artifacts.payloads import ( # noqa: F401
_retrieve_example_payloads,
)
85 changes: 85 additions & 0 deletions src/sagemaker/jumpstart/artifacts/payloads.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""This module contains functions to obtain JumpStart model payloads."""
from __future__ import absolute_import
from copy import deepcopy
from typing import Dict, Optional
from sagemaker.jumpstart.constants import (
DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
JUMPSTART_DEFAULT_REGION_NAME,
)
from sagemaker.jumpstart.enums import (
JumpStartScriptScope,
)
from sagemaker.jumpstart.types import JumpStartSerializablePayload
from sagemaker.jumpstart.utils import (
verify_model_region_and_return_specs,
)
from sagemaker.session import Session


def _retrieve_example_payloads(
model_id: str,
model_version: str,
region: Optional[str],
tolerate_vulnerable_model: bool = False,
tolerate_deprecated_model: bool = False,
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
) -> Optional[Dict[str, JumpStartSerializablePayload]]:
"""Returns example payloads.

Args:
model_id (str): JumpStart model ID of the JumpStart model for which to
get example payloads.
model_version (str): Version of the JumpStart model for which to retrieve the
example payloads.
region (Optional[str]): Region for which to retrieve the
example payloads.
tolerate_vulnerable_model (bool): True if vulnerable versions of model
specifications should be tolerated (exception not raised). If False, raises an
exception if the script used by this version of the model has dependencies with known
security vulnerabilities. (Default: False).
tolerate_deprecated_model (bool): True if deprecated versions of model
specifications should be tolerated (exception not raised). If False, raises
an exception if the version of the model is deprecated. (Default: False).
sagemaker_session (sagemaker.session.Session): A SageMaker Session
object, used for SageMaker interactions. If not
specified, one is created using the default AWS configuration
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
Returns:
Optional[Dict[str, JumpStartSerializablePayload]]: dictionary mapping payload aliases
to the serializable payload object.
"""

if region is None:
region = JUMPSTART_DEFAULT_REGION_NAME

model_specs = verify_model_region_and_return_specs(
model_id=model_id,
version=model_version,
scope=JumpStartScriptScope.INFERENCE,
region=region,
tolerate_vulnerable_model=tolerate_vulnerable_model,
tolerate_deprecated_model=tolerate_deprecated_model,
sagemaker_session=sagemaker_session,
)

default_payloads = model_specs.default_payloads

if default_payloads:
for payload in default_payloads.values():
payload.accept = getattr(
payload, "accept", model_specs.predictor_specs.default_accept_type
)

return deepcopy(default_payloads) if default_payloads else None
42 changes: 42 additions & 0 deletions src/sagemaker/jumpstart/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import re

from typing import Dict, List, Optional, Union
from sagemaker import payloads
from sagemaker.async_inference.async_inference_config import AsyncInferenceConfig
from sagemaker.base_deserializers import BaseDeserializer
from sagemaker.base_serializers import BaseSerializer
Expand All @@ -28,6 +29,7 @@
get_deploy_kwargs,
get_init_kwargs,
)
from sagemaker.jumpstart.types import JumpStartSerializablePayload
from sagemaker.jumpstart.utils import is_valid_model_id
from sagemaker.utils import stringify_object
from sagemaker.model import MODEL_PACKAGE_ARN_PATTERN, Model
Expand Down Expand Up @@ -312,6 +314,46 @@ def _is_valid_model_id_hook():

super(JumpStartModel, self).__init__(**model_init_kwargs.to_kwargs_dict())

def retrieve_all_examples(self) -> Optional[List[JumpStartSerializablePayload]]:
"""Returns all example payloads associated with the model.

Raises:
NotImplementedError: If the scope is not supported.
ValueError: If the combination of arguments specified is not supported.
VulnerableJumpStartModelError: If any of the dependencies required by the script have
known security vulnerabilities.
DeprecatedJumpStartModelError: If the version of the model is deprecated.
"""
return payloads.retrieve_all_examples(
model_id=self.model_id,
model_version=self.model_version,
region=self.region,
tolerate_deprecated_model=self.tolerate_deprecated_model,
tolerate_vulnerable_model=self.tolerate_vulnerable_model,
sagemaker_session=self.sagemaker_session,
)

def retrieve_example_payload(self) -> JumpStartSerializablePayload:
"""Returns the example payload associated with the model.

Payload can be directly used with the `sagemaker.predictor.Predictor.predict(...)` function.

Raises:
NotImplementedError: If the scope is not supported.
ValueError: If the combination of arguments specified is not supported.
VulnerableJumpStartModelError: If any of the dependencies required by the script have
known security vulnerabilities.
DeprecatedJumpStartModelError: If the version of the model is deprecated.
"""
return payloads.retrieve_example(
model_id=self.model_id,
model_version=self.model_version,
region=self.region,
tolerate_deprecated_model=self.tolerate_deprecated_model,
tolerate_vulnerable_model=self.tolerate_vulnerable_model,
sagemaker_session=self.sagemaker_session,
)

def _create_sagemaker_model(
self,
instance_type=None,
Expand Down
Loading