Skip to content

Commit 3712df6

Browse files
committed
feat: client cache for jumpstart models
1 parent 7268e82 commit 3712df6

File tree

14 files changed

+1663
-0
lines changed

14 files changed

+1663
-0
lines changed

setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ def read_version():
4444
"packaging>=20.0",
4545
"pandas",
4646
"pathos",
47+
"semantic-version",
4748
]
4849

4950
# Specific use case dependencies

src/sagemaker/jumpstart/__init__.py

Whitespace-only changes.

src/sagemaker/jumpstart/cache.py

Lines changed: 280 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,280 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
import datetime
14+
from typing import List, Optional
15+
from sagemaker.jumpstart.types import (
16+
JumpStartCachedS3ContentKey,
17+
JumpStartCachedS3ContentValue,
18+
JumpStartModelHeader,
19+
JumpStartModelSpecs,
20+
JumpStartModelSpecs,
21+
JumpStartS3FileType,
22+
JumpStartVersionedModelId,
23+
)
24+
from sagemaker.jumpstart import utils
25+
from sagemaker.utilities.cache import LRUCache
26+
import boto3
27+
import json
28+
import semantic_version
29+
30+
31+
DEFAULT_REGION_NAME = boto3.session.Session().region_name
32+
33+
DEFAULT_MAX_S3_CACHE_ITEMS = 20
34+
DEFAULT_S3_CACHE_EXPIRATION_TIME = datetime.timedelta(hours=6)
35+
36+
DEFAULT_MAX_SEMANTIC_VERSION_CACHE_ITEMS = 20
37+
DEFAULT_SEMANTIC_VERSION_CACHE_EXPIRATION_TIME = datetime.timedelta(hours=6)
38+
39+
DEFAULT_MANIFEST_FILE_S3_KEY = "models_manifest.json"
40+
41+
42+
class JumpStartModelsCache:
43+
"""Class that implements a cache for JumpStart models manifests and specs.
44+
The manifest and specs associated with JumpStart models provide the information necessary
45+
for launching JumpStart models from the SageMaker SDK.
46+
"""
47+
48+
def __init__(
49+
self,
50+
region: Optional[str] = DEFAULT_REGION_NAME,
51+
max_s3_cache_items: Optional[int] = DEFAULT_MAX_S3_CACHE_ITEMS,
52+
s3_cache_expiration_time: Optional[datetime.timedelta] = DEFAULT_S3_CACHE_EXPIRATION_TIME,
53+
max_semantic_version_cache_items: Optional[int] = DEFAULT_MAX_SEMANTIC_VERSION_CACHE_ITEMS,
54+
semantic_version_cache_expiration_time: Optional[
55+
datetime.timedelta
56+
] = DEFAULT_SEMANTIC_VERSION_CACHE_EXPIRATION_TIME,
57+
manifest_file_s3_key: Optional[str] = DEFAULT_MANIFEST_FILE_S3_KEY,
58+
bucket: Optional[str] = None,
59+
) -> None:
60+
"""Initialize a ``JumpStartModelsCache`` instance.
61+
62+
Args:
63+
region (Optional[str]): AWS region to associate with cache. Default: region associated
64+
with botocore session.
65+
max_s3_cache_items (Optional[int]): Maximum number of files to store in s3 cache. Default: 20.
66+
s3_cache_expiration_time (Optional[datetime.timedelta]): Maximum time to hold items in s3
67+
cache before invalidation. Default: 6 hours.
68+
max_semantic_version_cache_items (Optional[int]): Maximum number of files to store in
69+
semantic version cache. Default: 20.
70+
semantic_version_cache_expiration_time (Optional[datetime.timedelta]): Maximum time to hold
71+
items in semantic version cache before invalidation. Default: 6 hours.
72+
bucket (Optional[str]): S3 bucket to associate with cache. Default: JumpStart-hosted content
73+
bucket for region.
74+
"""
75+
76+
self._region = region
77+
self._s3_cache = LRUCache[JumpStartCachedS3ContentKey, JumpStartCachedS3ContentValue](
78+
max_cache_items=max_s3_cache_items,
79+
expiration_time=s3_cache_expiration_time,
80+
retrieval_function=self._get_file_from_s3,
81+
)
82+
self._model_id_semantic_version_manifest_key_cache = LRUCache[
83+
JumpStartVersionedModelId, JumpStartVersionedModelId
84+
](
85+
max_cache_items=max_semantic_version_cache_items,
86+
expiration_time=semantic_version_cache_expiration_time,
87+
retrieval_function=self._get_manifest_key_from_model_id_semantic_version,
88+
)
89+
self._manifest_file_s3_key = manifest_file_s3_key
90+
self._bucket = (
91+
utils.get_jumpstart_content_bucket(self._region) if bucket is None else bucket
92+
)
93+
self._has_retried_cache_refresh = False
94+
95+
def set_region(self, region: str) -> None:
96+
"""Set region for cache. Clears cache after new region is set."""
97+
self._region = region
98+
self.clear()
99+
100+
def get_region(self) -> str:
101+
"""Return region for cache."""
102+
return self._region
103+
104+
def set_manifest_file_s3_key(self, key: str) -> None:
105+
"""Set manifest file s3 key. Clears cache after new key is set."""
106+
self._manifest_file_s3_key = key
107+
self.clear()
108+
109+
def get_manifest_file_s3_key(self) -> None:
110+
"""Return manifest file s3 key for cache."""
111+
return self._manifest_file_s3_key
112+
113+
def set_bucket(self, bucket: str) -> None:
114+
"""Set s3 bucket used for cache."""
115+
self._bucket = bucket
116+
self.clear()
117+
118+
def get_bucket(self) -> None:
119+
"""Return bucket used for cache."""
120+
return self._bucket
121+
122+
def _get_manifest_key_from_model_id_semantic_version(
123+
self, key: JumpStartVersionedModelId, value: Optional[JumpStartVersionedModelId]
124+
) -> JumpStartVersionedModelId:
125+
"""Return model id and version in manifest that matches semantic version/id
126+
from customer request.
127+
128+
Args:
129+
key (JumpStartVersionedModelId): Key for which to fetch versioned model id.
130+
value (Optional[JumpStartVersionedModelId]): Unused variable for current value of old cached
131+
model id/version.
132+
133+
Raises:
134+
KeyError: If the semantic version is not found in the manifest.
135+
"""
136+
137+
model_id, version = key.model_id, key.version
138+
139+
manifest = self._s3_cache.get(
140+
JumpStartCachedS3ContentKey(JumpStartS3FileType.MANIFEST, self._manifest_file_s3_key)
141+
).formatted_file_content
142+
143+
sm_version = utils.get_sagemaker_version()
144+
145+
versions_compatible_with_sagemaker = [
146+
semantic_version.Version(header.version)
147+
for _, header in manifest.items()
148+
if header.model_id == model_id
149+
and semantic_version.Version(header.min_version) <= semantic_version.Version(sm_version)
150+
]
151+
152+
spec = (
153+
semantic_version.SimpleSpec("*")
154+
if version is None
155+
else semantic_version.SimpleSpec(version)
156+
)
157+
158+
sm_compatible_model_version = spec.select(versions_compatible_with_sagemaker)
159+
if sm_compatible_model_version is not None:
160+
return JumpStartVersionedModelId(model_id, str(sm_compatible_model_version))
161+
else:
162+
versions_incompatible_with_sagemaker = [
163+
semantic_version.Version(header.version)
164+
for _, header in manifest.items()
165+
if header.model_id == model_id
166+
]
167+
sm_incompatible_model_version = spec.select(versions_incompatible_with_sagemaker)
168+
if sm_incompatible_model_version is not None:
169+
model_version_to_use_incompatible_with_sagemaker = str(
170+
sm_incompatible_model_version
171+
)
172+
sm_version_to_use = [
173+
header.min_version
174+
for _, header in manifest.items()
175+
if header.model_id == model_id
176+
and header.version == model_version_to_use_incompatible_with_sagemaker
177+
]
178+
assert len(sm_version_to_use) == 1
179+
sm_version_to_use = sm_version_to_use[0]
180+
181+
error_msg = (
182+
f"Unable to find model manifest for {model_id} with version {version} compatible with your SageMaker version ({sm_version}). "
183+
f"Consider upgrading your SageMaker library to at least version {sm_version_to_use} so you can use version "
184+
f"{model_version_to_use_incompatible_with_sagemaker} of {model_id}."
185+
)
186+
raise KeyError(error_msg)
187+
else:
188+
error_msg = f"Unable to find model manifest for {model_id} with version {version}"
189+
raise KeyError(error_msg)
190+
191+
def _get_file_from_s3(
192+
self,
193+
key: JumpStartCachedS3ContentKey,
194+
value: Optional[JumpStartCachedS3ContentValue],
195+
) -> JumpStartCachedS3ContentValue:
196+
"""Return s3 content given a file type and s3_key in ``JumpStartCachedS3ContentKey``.
197+
If a manifest file is being fetched, we only download the object if the md5 hash in
198+
``head_object`` does not match the current md5 hash for the stored value. This prevents
199+
unnecessarily downloading the full manifest when it hasn't changed.
200+
201+
Args:
202+
key (JumpStartCachedS3ContentKey): key for which to fetch s3 content.
203+
value (Optional[JumpStartVersionedModelId]): Current value of old cached
204+
s3 content. This is used for the manifest file, so that it is only
205+
downloaded when its content changes.
206+
"""
207+
208+
file_type, s3_key = key.file_type, key.s3_key
209+
210+
s3_client = boto3.client("s3", region_name=self._region)
211+
212+
if file_type == JumpStartS3FileType.MANIFEST:
213+
etag = s3_client.head_object(Bucket=self._bucket, Key=s3_key)["ETag"]
214+
if value is not None and etag == value.md5_hash:
215+
return value
216+
response = s3_client.get_object(Bucket=self._bucket, Key=s3_key)
217+
formatted_body = json.loads(response["Body"].read().decode("utf-8"))
218+
return JumpStartCachedS3ContentValue(
219+
formatted_file_content=utils.get_formatted_manifest(formatted_body),
220+
md5_hash=etag,
221+
)
222+
if file_type == JumpStartS3FileType.SPECS:
223+
response = s3_client.get_object(Bucket=self._bucket, Key=s3_key)
224+
formatted_body = json.loads(response["Body"].read().decode("utf-8"))
225+
return JumpStartCachedS3ContentValue(
226+
formatted_file_content=JumpStartModelSpecs(formatted_body)
227+
)
228+
raise RuntimeError(f"Bad value for key: {key}")
229+
230+
def get_header(
231+
self, model_id: str, semantic_version: Optional[str] = None
232+
) -> List[JumpStartModelHeader]:
233+
"""Return list of headers for a given JumpStart model id and semantic version.
234+
235+
Args:
236+
model_id (str): model id for which to get a header.
237+
semantic_version (Optional[str]): The semantic version for which to get a header.
238+
If None, the highest compatible version is returned.
239+
"""
240+
241+
versioned_model_id = self._model_id_semantic_version_manifest_key_cache.get(
242+
JumpStartVersionedModelId(model_id, semantic_version)
243+
)
244+
manifest = self._s3_cache.get(
245+
JumpStartCachedS3ContentKey(JumpStartS3FileType.MANIFEST, self._manifest_file_s3_key)
246+
).formatted_file_content
247+
try:
248+
header = manifest[versioned_model_id]
249+
if self._has_retried_cache_refresh:
250+
self._has_retried_cache_refresh = False
251+
return header
252+
except KeyError:
253+
if self._has_retried_cache_refresh:
254+
self._has_retried_cache_refresh = False
255+
raise
256+
self.clear()
257+
self._has_retried_cache_refresh = True
258+
return self.get_header(model_id, semantic_version)
259+
260+
def get_specs(
261+
self, model_id: str, semantic_version: Optional[str] = None
262+
) -> JumpStartModelSpecs:
263+
"""Return specs for a given JumpStart model id and semantic version.
264+
265+
Args:
266+
model_id (str): model id for which to get specs.
267+
semantic_version (Optional[str]): The semantic version for which to get specs.
268+
If None, the highest compatible version is returned.
269+
"""
270+
header = self.get_header(model_id, semantic_version)
271+
spec_key = header.spec_key
272+
return self._s3_cache.get(
273+
JumpStartCachedS3ContentKey(JumpStartS3FileType.SPECS, spec_key)
274+
).formatted_file_content
275+
276+
def clear(self) -> None:
277+
"""Clears the model id/version and s3 cache and resets ``_has_retried_cache_refresh``."""
278+
self._s3_cache.clear()
279+
self._model_id_semantic_version_manifest_key_cache.clear()
280+
self._has_retried_cache_refresh = False

src/sagemaker/jumpstart/constants.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from typing import Set
14+
from sagemaker.jumpstart.types import JumpStartLaunchedRegionInfo
15+
16+
17+
LAUNCHED_REGIONS: Set[JumpStartLaunchedRegionInfo] = set()
18+
19+
REGION_NAME_TO_LAUNCHED_REGION_DICT = {region.region_name: region for region in LAUNCHED_REGIONS}
20+
REGION_NAME_SET = {region.region_name for region in LAUNCHED_REGIONS}

0 commit comments

Comments
 (0)