Skip to content

feature: client cache for jumpstart models #2756

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .coveragerc
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
[run]
concurrency = threading
concurrency = thread
omit = sagemaker/tests/*
timid = True
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ def read_version():
"packaging>=20.0",
"pandas",
"pathos",
"semantic-version",
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need to add an extra package? What's the use of this?

Copy link
Member

@mufaddal-rohawala mufaddal-rohawala Dec 6, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@evakravi There are reservations about adding new dependencies to the SDK due to some inherent constraints. Is it possible to use an alternative that already exists in the SDK for versioning? Can we use packaging.version here?

]

# Specific use case dependencies
Expand Down
Empty file.
327 changes: 327 additions & 0 deletions src/sagemaker/jumpstart/cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,327 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""This module defines the JumpStartModelsCache class."""
from __future__ import absolute_import
import datetime
from typing import List, Optional
import json
import boto3
import botocore
import semantic_version
from sagemaker.jumpstart.constants import (
JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY,
JUMPSTART_DEFAULT_REGION_NAME,
)
from sagemaker.jumpstart.parameters import (
JUMPSTART_DEFAULT_MAX_S3_CACHE_ITEMS,
JUMPSTART_DEFAULT_MAX_SEMANTIC_VERSION_CACHE_ITEMS,
JUMPSTART_DEFAULT_S3_CACHE_EXPIRATION_HORIZON,
JUMPSTART_DEFAULT_SEMANTIC_VERSION_CACHE_EXPIRATION_HORIZON,
)
from sagemaker.jumpstart.types import (
JumpStartCachedS3ContentKey,
JumpStartCachedS3ContentValue,
JumpStartModelHeader,
JumpStartModelSpecs,
JumpStartS3FileType,
JumpStartVersionedModelId,
)
from sagemaker.jumpstart import utils
from sagemaker.utilities.cache import LRUCache


class JumpStartModelsCache:
"""Class that implements a cache for JumpStart models manifests and specs.

The manifest and specs associated with JumpStart models provide the information necessary
for launching JumpStart models from the SageMaker SDK.
"""

def __init__(
    self,
    region: Optional[str] = JUMPSTART_DEFAULT_REGION_NAME,
    max_s3_cache_items: Optional[int] = JUMPSTART_DEFAULT_MAX_S3_CACHE_ITEMS,
    s3_cache_expiration_horizon: Optional[
        datetime.timedelta
    ] = JUMPSTART_DEFAULT_S3_CACHE_EXPIRATION_HORIZON,
    max_semantic_version_cache_items: Optional[
        int
    ] = JUMPSTART_DEFAULT_MAX_SEMANTIC_VERSION_CACHE_ITEMS,
    semantic_version_cache_expiration_horizon: Optional[
        datetime.timedelta
    ] = JUMPSTART_DEFAULT_SEMANTIC_VERSION_CACHE_EXPIRATION_HORIZON,
    manifest_file_s3_key: Optional[str] = JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY,
    s3_bucket_name: Optional[str] = None,
    s3_client_config: Optional[botocore.config.Config] = None,
) -> None:
    """Set up the caches and the S3 client of a ``JumpStartModelsCache``.

    Two LRU caches are created: one holding S3 content (manifest and specs
    files), and one mapping a requested model id/semantic version to the
    matching model id/version found in the manifest.

    Args:
        region (Optional[str]): AWS region to associate with cache. Default: region
            associated with boto3 session.
        max_s3_cache_items (Optional[int]): Maximum number of items to store in s3 cache.
            Default: 20.
        s3_cache_expiration_horizon (Optional[datetime.timedelta]): Maximum time to hold
            items in s3 cache before invalidation. Default: 6 hours.
        max_semantic_version_cache_items (Optional[int]): Maximum number of items to
            store in semantic version cache. Default: 20.
        semantic_version_cache_expiration_horizon (Optional[datetime.timedelta]):
            Maximum time to hold items in semantic version cache before invalidation.
            Default: 6 hours.
        manifest_file_s3_key (Optional[str]): S3 key of the models manifest file.
            Default: "models_manifest.json".
        s3_bucket_name (Optional[str]): S3 bucket to associate with cache.
            Default: JumpStart-hosted content bucket for region.
        s3_client_config (Optional[botocore.config.Config]): s3 client config to use
            for cache. Default: None (no config).
    """
    self._region = region

    # Cache of downloaded S3 files, filled on demand by ``_get_file_from_s3``.
    s3_content_cache = LRUCache[JumpStartCachedS3ContentKey, JumpStartCachedS3ContentValue](
        max_cache_items=max_s3_cache_items,
        expiration_horizon=s3_cache_expiration_horizon,
        retrieval_function=self._get_file_from_s3,
    )
    self._s3_cache = s3_content_cache

    # Cache of requested (model id, semantic version) -> manifest (model id, version).
    version_resolution_cache = LRUCache[JumpStartVersionedModelId, JumpStartVersionedModelId](
        max_cache_items=max_semantic_version_cache_items,
        expiration_horizon=semantic_version_cache_expiration_horizon,
        retrieval_function=self._get_manifest_key_from_model_id_semantic_version,
    )
    self._model_id_semantic_version_manifest_key_cache = version_resolution_cache

    self._manifest_file_s3_key = manifest_file_s3_key

    # Fall back to the JumpStart content bucket for the region only when the
    # caller did not name a bucket explicitly.
    if s3_bucket_name is None:
        self.s3_bucket_name = utils.get_jumpstart_content_bucket(self._region)
    else:
        self.s3_bucket_name = s3_bucket_name

    if s3_client_config:
        self._s3_client = boto3.client(
            "s3", region_name=self._region, config=s3_client_config
        )
    else:
        self._s3_client = boto3.client("s3", region_name=self._region)

def set_region(self, region: str) -> None:
"""Set region for cache. Clears cache after new region is set."""
if region != self._region:
self._region = region
self.clear()

def get_region(self) -> str:
"""Return region for cache."""
return self._region

def set_manifest_file_s3_key(self, key: str) -> None:
"""Set manifest file s3 key. Clears cache after new key is set."""
if key != self._manifest_file_s3_key:
self._manifest_file_s3_key = key
self.clear()

def get_manifest_file_s3_key(self) -> None:
"""Return manifest file s3 key for cache."""
return self._manifest_file_s3_key

def set_s3_bucket_name(self, s3_bucket_name: str) -> None:
"""Set s3 bucket used for cache."""
if s3_bucket_name != self.s3_bucket_name:
self.s3_bucket_name = s3_bucket_name
self.clear()

def get_bucket(self) -> None:
"""Return bucket used for cache."""
return self.s3_bucket_name

def _get_manifest_key_from_model_id_semantic_version(
self,
key: JumpStartVersionedModelId,
value: Optional[JumpStartVersionedModelId], # pylint: disable=W0613
) -> JumpStartVersionedModelId:
"""Return model id and version in manifest that matches semantic version/id.

Uses ``semantic_version`` to perform version comparison. The highest model version
matching the semantic version is used, which is compatible with the SageMaker
version.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add a line to explain the SDK version compatibility check.

Args:
key (JumpStartVersionedModelId): Key for which to fetch versioned model id.
value (Optional[JumpStartVersionedModelId]): Unused variable for current value of
old cached model id/version.

Raises:
KeyError: If the semantic version is not found in the manifest, or is found but
the SageMaker version needs to be upgraded in order for the model to be used.
"""

model_id, version = key.model_id, key.version

manifest = self._s3_cache.get(
JumpStartCachedS3ContentKey(JumpStartS3FileType.MANIFEST, self._manifest_file_s3_key)
).formatted_content

sm_version = utils.get_sagemaker_version()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

question: should this be constant defined once?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's tricky because of circular imports. Let's leave it like this for now. I don't think this is an expensive operation.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes it makes sense, you can't import sagemaker within sagemaker. I agree that should be cheap enough. You could also cache it at runtime:

sagemaker_version = ""

def get_sagemaker_version() -> str:
    """Return the SageMaker version, parsing it only on the first call.

    The parsed value is stored in the module-level ``sagemaker_version`` so
    later calls return it without parsing again.
    """
    # Without this declaration the assignment below makes the name local to
    # the function, and the ``if`` raises UnboundLocalError on every call.
    global sagemaker_version
    if not sagemaker_version:
        sagemaker_version = parse_sagemaker_version()
    return sagemaker_version

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good idea. This'll go in the next commit.


versions_compatible_with_sagemaker = [
semantic_version.Version(header.version)
for header in manifest.values()
if header.model_id == model_id
and semantic_version.Version(header.min_version) <= semantic_version.Version(sm_version)
]

spec = (
semantic_version.SimpleSpec("*")
if version is None
else semantic_version.SimpleSpec(version)
)

sm_compatible_model_version = spec.select(versions_compatible_with_sagemaker)
if sm_compatible_model_version is not None:
return JumpStartVersionedModelId(model_id, str(sm_compatible_model_version))

versions_incompatible_with_sagemaker = [
semantic_version.Version(header.version)
for header in manifest.values()
if header.model_id == model_id
]
sm_incompatible_model_version = spec.select(versions_incompatible_with_sagemaker)
if sm_incompatible_model_version is not None:
model_version_to_use_incompatible_with_sagemaker = str(sm_incompatible_model_version)
sm_version_to_use = [
header.min_version
for header in manifest.values()
if header.model_id == model_id
and header.version == model_version_to_use_incompatible_with_sagemaker
]
if len(sm_version_to_use) != 1:
# ``manifest`` dict should already enforce this
raise RuntimeError("Found more than one incompatible SageMaker version to use.")
sm_version_to_use = sm_version_to_use[0]

error_msg = (
f"Unable to find model manifest for {model_id} with version {version} "
f"compatible with your SageMaker version ({sm_version}). "
f"Consider upgrading your SageMaker library to at least version "
f"{sm_version_to_use} so you can use version "
f"{model_version_to_use_incompatible_with_sagemaker} of {model_id}."
)
raise KeyError(error_msg)
error_msg = f"Unable to find model manifest for {model_id} with version {version}."
raise KeyError(error_msg)

def _get_file_from_s3(
self,
key: JumpStartCachedS3ContentKey,
value: Optional[JumpStartCachedS3ContentValue],
) -> JumpStartCachedS3ContentValue:
"""Return s3 content given a file type and s3_key in ``JumpStartCachedS3ContentKey``.

If a manifest file is being fetched, we only download the object if the md5 hash in
``head_object`` does not match the current md5 hash for the stored value. This prevents
unnecessarily downloading the full manifest when it hasn't changed.

Args:
key (JumpStartCachedS3ContentKey): key for which to fetch s3 content.
value (Optional[JumpStartCachedS3ContentValue]): Current value of old cached
s3 content. This is used for the manifest file, so that it is only
downloaded when its content changes.
"""

file_type, s3_key = key.file_type, key.s3_key

if file_type == JumpStartS3FileType.MANIFEST:
if value is not None:
etag = self._s3_client.head_object(Bucket=self.s3_bucket_name, Key=s3_key)["ETag"]
if etag == value.md5_hash:
return value
response = self._s3_client.get_object(Bucket=self.s3_bucket_name, Key=s3_key)
formatted_body = json.loads(response["Body"].read().decode("utf-8"))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Several issues here:

  • if value is None, you are making an unnecessary HTTP call
  • if value is not None but the ETags differ, you risk caching the wrong ETag

Does this work?

if value is not None:
    etag = self._s3_client.head_object(Bucket=self.s3_bucket_name, Key=s3_key)["ETag"]
    if  etag == value.md5_hash:
         return value

response = self._s3_client.get_object(Bucket=self.s3_bucket_name, Key=s3_key)
formatted_body = json.loads(response["Body"].read().decode("utf-8"))
etag = response["ETag"]
return JumpStartCachedS3ContentValue(
    formatted_file_content=utils.get_formatted_manifest(formatted_body),
    md5_hash=etag,
)

While we are at it, can you unit test this behavior please?

etag = response["ETag"]
return JumpStartCachedS3ContentValue(
formatted_content=utils.get_formatted_manifest(formatted_body),
md5_hash=etag,
)
if file_type == JumpStartS3FileType.SPECS:
response = self._s3_client.get_object(Bucket=self.s3_bucket_name, Key=s3_key)
formatted_body = json.loads(response["Body"].read().decode("utf-8"))
return JumpStartCachedS3ContentValue(
formatted_content=JumpStartModelSpecs(formatted_body)
)
raise ValueError(
f"Bad value for key '{key}': must be in {[JumpStartS3FileType.MANIFEST, JumpStartS3FileType.SPECS]}"
)

def get_manifest(self) -> List[JumpStartModelHeader]:
    """Return every header in the JumpStart models manifest.

    The manifest is read through the s3 cache, using the manifest file S3 key
    currently stored on the instance.

    Returns:
        List[JumpStartModelHeader]: The values of the cached manifest, as a list.
    """
    manifest_key = JumpStartCachedS3ContentKey(
        JumpStartS3FileType.MANIFEST, self._manifest_file_s3_key
    )
    manifest = self._s3_cache.get(manifest_key).formatted_content
    # ``values()`` yields a view rather than a list; materialize it so the
    # result matches the annotated return type and does not alias the cache.
    return list(manifest.values())

def get_header(self, model_id: str, semantic_version_str: str) -> JumpStartModelHeader:
    """Look up the manifest header for a JumpStart model id and semantic version.

    Delegates to ``_get_header_impl`` with its default attempt number.

    Args:
        model_id (str): model id for which to get a header.
        semantic_version_str (str): The semantic version for which to get a
            header.
    """
    return self._get_header_impl(model_id, semantic_version_str)

def _get_header_impl(
    self,
    model_id: str,
    semantic_version_str: str,
    attempt: Optional[int] = 0,
) -> JumpStartModelHeader:
    """Resolve a model id/semantic version to its manifest header.

    The requested id/version is first resolved through the semantic version
    cache, then the result is looked up in the cached manifest. If that
    lookup raises ``KeyError`` on the first attempt, both caches are cleared
    and the lookup is retried once; a second ``KeyError`` propagates.

    Args:
        model_id (str): model id for which to get a header.
        semantic_version_str (str): The semantic version for which to get a
            header.
        attempt (Optional[int]): attempt number at retrieving a header.
    """
    requested_id = JumpStartVersionedModelId(model_id, semantic_version_str)
    resolved_id = self._model_id_semantic_version_manifest_key_cache.get(requested_id)

    manifest_key = JumpStartCachedS3ContentKey(
        JumpStartS3FileType.MANIFEST, self._manifest_file_s3_key
    )
    headers_by_id = self._s3_cache.get(manifest_key).formatted_content

    try:
        return headers_by_id[resolved_id]
    except KeyError:
        retry_allowed = not attempt > 0
        if not retry_allowed:
            raise
        # The resolved id is missing from the cached manifest: drop both
        # caches and try one more time.
        self.clear()
        return self._get_header_impl(model_id, semantic_version_str, attempt + 1)

def get_specs(self, model_id: str, semantic_version_str: str) -> JumpStartModelSpecs:
    """Fetch the specs for a JumpStart model id and semantic version.

    The header is looked up first; its ``spec_key`` is then used to read the
    specs file through the s3 cache.

    Args:
        model_id (str): model id for which to get specs.
        semantic_version_str (str): The semantic version for which to get
            specs.
    """
    specs_s3_key = self.get_header(model_id, semantic_version_str).spec_key
    cached_specs = self._s3_cache.get(
        JumpStartCachedS3ContentKey(JumpStartS3FileType.SPECS, specs_s3_key)
    )
    return cached_specs.formatted_content

def clear(self) -> None:
"""Clears the model id/version and s3 cache."""
self._s3_cache.clear()
self._model_id_semantic_version_manifest_key_cache.clear()
29 changes: 29 additions & 0 deletions src/sagemaker/jumpstart/constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""This module stores constants related to SageMaker JumpStart."""
from __future__ import absolute_import
from typing import Set
import boto3
from sagemaker.jumpstart.types import JumpStartLaunchedRegionInfo


# Set of launched-region descriptors. It is empty here, so the two lookups
# derived from it below are empty as well.
JUMPSTART_LAUNCHED_REGIONS: Set[JumpStartLaunchedRegionInfo] = set()

# Maps each launched region's ``region_name`` to its descriptor.
JUMPSTART_REGION_NAME_TO_LAUNCHED_REGION_DICT = {
region.region_name: region for region in JUMPSTART_LAUNCHED_REGIONS
}
# Names of all launched regions.
JUMPSTART_REGION_NAME_SET = {region.region_name for region in JUMPSTART_LAUNCHED_REGIONS}

# Region of a default boto3 session, resolved once when this module is imported.
# NOTE(review): presumably ``None`` when no default region is configured — confirm.
JUMPSTART_DEFAULT_REGION_NAME = boto3.session.Session().region_name

# Default S3 key of the models manifest file.
JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY = "models_manifest.json"
20 changes: 20 additions & 0 deletions src/sagemaker/jumpstart/parameters.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""This module stores parameters related to SageMaker JumpStart."""
from __future__ import absolute_import
import datetime

# Default maximum number of entries held by the s3 content cache.
JUMPSTART_DEFAULT_MAX_S3_CACHE_ITEMS = 20
# Default maximum number of entries held by the semantic version cache.
JUMPSTART_DEFAULT_MAX_SEMANTIC_VERSION_CACHE_ITEMS = 20
# Default time an s3 cache entry is held before invalidation.
JUMPSTART_DEFAULT_S3_CACHE_EXPIRATION_HORIZON = datetime.timedelta(hours=6)
# Default time a semantic version cache entry is held before invalidation.
JUMPSTART_DEFAULT_SEMANTIC_VERSION_CACHE_EXPIRATION_HORIZON = datetime.timedelta(hours=6)
Loading