Skip to content

Commit d7ff8f8

Browse files
committed
fix: update jumpstart cache and unit tests
1 parent 37a502e commit d7ff8f8

File tree

10 files changed

+342
-165
lines changed

10 files changed

+342
-165
lines changed

src/sagemaker/jumpstart/cache.py

Lines changed: 98 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,21 @@
1313
"""This module defines the JumpStartModelsCache class."""
1414
from __future__ import absolute_import
1515
import datetime
16-
from typing import List, Optional
16+
from typing import Optional
1717
import json
1818
import boto3
19+
import botocore
1920
import semantic_version
21+
from sagemaker.jumpstart.constants import (
22+
JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY,
23+
JUMPSTART_DEFAULT_REGION_NAME,
24+
)
25+
from sagemaker.jumpstart.parameters import (
26+
JUMPSTART_DEFAULT_MAX_S3_CACHE_ITEMS,
27+
JUMPSTART_DEFAULT_MAX_SEMANTIC_VERSION_CACHE_ITEMS,
28+
JUMPSTART_DEFAULT_S3_CACHE_EXPIRATION_HORIZON,
29+
JUMPSTART_DEFAULT_SEMANTIC_VERSION_CACHE_EXPIRATION_HORIZON,
30+
)
2031
from sagemaker.jumpstart.types import (
2132
JumpStartCachedS3ContentKey,
2233
JumpStartCachedS3ContentValue,
@@ -28,16 +39,6 @@
2839
from sagemaker.jumpstart import utils
2940
from sagemaker.utilities.cache import LRUCache
3041

31-
DEFAULT_REGION_NAME = boto3.session.Session().region_name
32-
33-
DEFAULT_MAX_S3_CACHE_ITEMS = 20
34-
DEFAULT_S3_CACHE_EXPIRATION_TIME = datetime.timedelta(hours=6)
35-
36-
DEFAULT_MAX_SEMANTIC_VERSION_CACHE_ITEMS = 20
37-
DEFAULT_SEMANTIC_VERSION_CACHE_EXPIRATION_TIME = datetime.timedelta(hours=6)
38-
39-
DEFAULT_MANIFEST_FILE_S3_KEY = "models_manifest.json"
40-
4142

4243
class JumpStartModelsCache:
4344
"""Class that implements a cache for JumpStart models manifests and specs.
@@ -48,78 +49,95 @@ class JumpStartModelsCache:
4849

4950
def __init__(
5051
self,
51-
region: Optional[str] = DEFAULT_REGION_NAME,
52-
max_s3_cache_items: Optional[int] = DEFAULT_MAX_S3_CACHE_ITEMS,
53-
s3_cache_expiration_time: Optional[datetime.timedelta] = DEFAULT_S3_CACHE_EXPIRATION_TIME,
54-
max_semantic_version_cache_items: Optional[int] = DEFAULT_MAX_SEMANTIC_VERSION_CACHE_ITEMS,
55-
semantic_version_cache_expiration_time: Optional[
52+
region: Optional[str] = JUMPSTART_DEFAULT_REGION_NAME,
53+
max_s3_cache_items: Optional[int] = JUMPSTART_DEFAULT_MAX_S3_CACHE_ITEMS,
54+
s3_cache_expiration_horizon: Optional[
5655
datetime.timedelta
57-
] = DEFAULT_SEMANTIC_VERSION_CACHE_EXPIRATION_TIME,
58-
manifest_file_s3_key: Optional[str] = DEFAULT_MANIFEST_FILE_S3_KEY,
59-
bucket: Optional[str] = None,
56+
] = JUMPSTART_DEFAULT_S3_CACHE_EXPIRATION_HORIZON,
57+
max_semantic_version_cache_items: Optional[
58+
int
59+
] = JUMPSTART_DEFAULT_MAX_SEMANTIC_VERSION_CACHE_ITEMS,
60+
semantic_version_cache_expiration_horizon: Optional[
61+
datetime.timedelta
62+
] = JUMPSTART_DEFAULT_SEMANTIC_VERSION_CACHE_EXPIRATION_HORIZON,
63+
manifest_file_s3_key: Optional[str] = JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY,
64+
s3_bucket_name: Optional[str] = None,
65+
s3_client_config: Optional[botocore.config.Config] = None,
6066
) -> None:
6167
"""Initialize a ``JumpStartModelsCache`` instance.
6268
6369
Args:
6470
region (Optional[str]): AWS region to associate with cache. Default: region associated
65-
with botocore session.
66-
max_s3_cache_items (Optional[int]): Maximum number of files to store in s3 cache.
71+
with boto3 session.
72+
max_s3_cache_items (Optional[int]): Maximum number of items to store in s3 cache.
6773
Default: 20.
68-
s3_cache_expiration_time (Optional[datetime.timedelta]): Maximum time to hold items in
69-
s3 cache before invalidation. Default: 6 hours.
70-
max_semantic_version_cache_items (Optional[int]): Maximum number of files to store in
74+
s3_cache_expiration_horizon (Optional[datetime.timedelta]): Maximum time to hold
75+
items in s3 cache before invalidation. Default: 6 hours.
76+
max_semantic_version_cache_items (Optional[int]): Maximum number of items to store in
7177
semantic version cache. Default: 20.
72-
semantic_version_cache_expiration_time (Optional[datetime.timedelta]): Maximum time to
73-
hold items in semantic version cache before invalidation. Default: 6 hours.
74-
bucket (Optional[str]): S3 bucket to associate with cache. Default: JumpStart-hosted
75-
content bucket for region.
78+
semantic_version_cache_expiration_horizon (Optional[datetime.timedelta]):
79+
Maximum time to hold items in semantic version cache before invalidation.
80+
Default: 6 hours.
81+
s3_bucket_name (Optional[str]): S3 bucket to associate with cache.
82+
Default: JumpStart-hosted content bucket for region.
83+
s3_client_config (Optional[botocore.config.Config]): s3 client config to use for cache.
84+
Default: None (no config).
7685
"""
7786

7887
self._region = region
7988
self._s3_cache = LRUCache[JumpStartCachedS3ContentKey, JumpStartCachedS3ContentValue](
8089
max_cache_items=max_s3_cache_items,
81-
expiration_time=s3_cache_expiration_time,
90+
expiration_horizon=s3_cache_expiration_horizon,
8291
retrieval_function=self._get_file_from_s3,
8392
)
8493
self._model_id_semantic_version_manifest_key_cache = LRUCache[
8594
JumpStartVersionedModelId, JumpStartVersionedModelId
8695
](
8796
max_cache_items=max_semantic_version_cache_items,
88-
expiration_time=semantic_version_cache_expiration_time,
97+
expiration_horizon=semantic_version_cache_expiration_horizon,
8998
retrieval_function=self._get_manifest_key_from_model_id_semantic_version,
9099
)
91100
self._manifest_file_s3_key = manifest_file_s3_key
92-
self._bucket = (
93-
utils.get_jumpstart_content_bucket(self._region) if bucket is None else bucket
101+
self.s3_bucket_name = (
102+
utils.get_jumpstart_content_bucket(self._region)
103+
if s3_bucket_name is None
104+
else s3_bucket_name
105+
)
106+
self._s3_client = (
107+
boto3.client("s3", region_name=self._region, config=s3_client_config)
108+
if s3_client_config
109+
else boto3.client("s3", region_name=self._region)
94110
)
95-
self._has_retried_cache_refresh = False
96111

97112
def set_region(self, region: str) -> None:
98113
"""Set region for cache. Clears cache after new region is set."""
99-
self._region = region
100-
self.clear()
114+
if region != self._region:
115+
self._region = region
116+
self.clear()
101117

102118
def get_region(self) -> str:
103119
"""Return region for cache."""
104120
return self._region
105121

106122
def set_manifest_file_s3_key(self, key: str) -> None:
107123
"""Set manifest file s3 key. Clears cache after new key is set."""
108-
self._manifest_file_s3_key = key
109-
self.clear()
124+
if key != self._manifest_file_s3_key:
125+
self._manifest_file_s3_key = key
126+
self.clear()
110127

111128
def get_manifest_file_s3_key(self) -> None:
112129
"""Return manifest file s3 key for cache."""
113130
return self._manifest_file_s3_key
114131

115-
def set_bucket(self, bucket: str) -> None:
132+
def set_s3_bucket_name(self, s3_bucket_name: str) -> None:
116133
"""Set s3 bucket used for cache. Clears cache after new bucket is set."""
117-
self._bucket = bucket
118-
self.clear()
134+
if s3_bucket_name != self.s3_bucket_name:
135+
self.s3_bucket_name = s3_bucket_name
136+
self.clear()
119137

120138
def get_bucket(self) -> None:
121139
"""Return bucket used for cache."""
122-
return self._bucket
140+
return self.s3_bucket_name
123141

124142
def _get_manifest_key_from_model_id_semantic_version(
125143
self,
@@ -128,13 +146,18 @@ def _get_manifest_key_from_model_id_semantic_version(
128146
) -> JumpStartVersionedModelId:
129147
"""Return model id and version in manifest that matches semantic version/id.
130148
149+
Uses ``semantic_version`` to perform version comparison. The highest model version
150+
that matches the semantic version and is compatible with the SageMaker
151+
version is used.
152+
131153
Args:
132154
key (JumpStartVersionedModelId): Key for which to fetch versioned model id.
133155
value (Optional[JumpStartVersionedModelId]): Unused variable for current value of
134156
old cached model id/version.
135157
136158
Raises:
137-
KeyError: If the semantic version is not found in the manifest.
159+
KeyError: If the semantic version is not found in the manifest, or is found but
160+
the SageMaker version needs to be upgraded in order for the model to be used.
138161
"""
139162

140163
model_id, version = key.model_id, key.version
@@ -147,7 +170,7 @@ def _get_manifest_key_from_model_id_semantic_version(
147170

148171
versions_compatible_with_sagemaker = [
149172
semantic_version.Version(header.version)
150-
for _, header in manifest.items()
173+
for header in manifest.values()
151174
if header.model_id == model_id
152175
and semantic_version.Version(header.min_version) <= semantic_version.Version(sm_version)
153176
]
@@ -164,19 +187,19 @@ def _get_manifest_key_from_model_id_semantic_version(
164187

165188
versions_incompatible_with_sagemaker = [
166189
semantic_version.Version(header.version)
167-
for _, header in manifest.items()
190+
for header in manifest.values()
168191
if header.model_id == model_id
169192
]
170193
sm_incompatible_model_version = spec.select(versions_incompatible_with_sagemaker)
171194
if sm_incompatible_model_version is not None:
172195
model_version_to_use_incompatible_with_sagemaker = str(sm_incompatible_model_version)
173196
sm_version_to_use = [
174197
header.min_version
175-
for _, header in manifest.items()
198+
for header in manifest.values()
176199
if header.model_id == model_id
177200
and header.version == model_version_to_use_incompatible_with_sagemaker
178201
]
179-
assert len(sm_version_to_use) == 1
202+
assert len(sm_version_to_use) == 1 # ``manifest`` dict should already enforce this
180203
sm_version_to_use = sm_version_to_use[0]
181204

182205
error_msg = (
@@ -187,7 +210,7 @@ def _get_manifest_key_from_model_id_semantic_version(
187210
f"{model_version_to_use_incompatible_with_sagemaker} of {model_id}."
188211
)
189212
raise KeyError(error_msg)
190-
error_msg = f"Unable to find model manifest for {model_id} with version {version}"
213+
error_msg = f"Unable to find model manifest for {model_id} with version {version}."
191214
raise KeyError(error_msg)
192215

193216
def _get_file_from_s3(
@@ -210,33 +233,49 @@ def _get_file_from_s3(
210233

211234
file_type, s3_key = key.file_type, key.s3_key
212235

213-
s3_client = boto3.client("s3", region_name=self._region)
214-
215236
if file_type == JumpStartS3FileType.MANIFEST:
216-
etag = s3_client.head_object(Bucket=self._bucket, Key=s3_key)["ETag"]
237+
etag = self._s3_client.head_object(Bucket=self.s3_bucket_name, Key=s3_key)["ETag"]
217238
if value is not None and etag == value.md5_hash:
218239
return value
219-
response = s3_client.get_object(Bucket=self._bucket, Key=s3_key)
240+
response = self._s3_client.get_object(Bucket=self.s3_bucket_name, Key=s3_key)
220241
formatted_body = json.loads(response["Body"].read().decode("utf-8"))
221242
return JumpStartCachedS3ContentValue(
222243
formatted_file_content=utils.get_formatted_manifest(formatted_body),
223244
md5_hash=etag,
224245
)
225246
if file_type == JumpStartS3FileType.SPECS:
226-
response = s3_client.get_object(Bucket=self._bucket, Key=s3_key)
247+
response = self._s3_client.get_object(Bucket=self.s3_bucket_name, Key=s3_key)
227248
formatted_body = json.loads(response["Body"].read().decode("utf-8"))
228249
return JumpStartCachedS3ContentValue(
229250
formatted_file_content=JumpStartModelSpecs(formatted_body)
230251
)
231-
raise RuntimeError(f"Bad value for key: {key}")
252+
raise ValueError(
253+
f"Bad value for key '{key}': must be in {[JumpStartS3FileType.MANIFEST, JumpStartS3FileType.SPECS]}"
254+
)
232255

233256
def get_header(
234257
self, model_id: str, semantic_version_str: Optional[str] = None
235-
) -> List[JumpStartModelHeader]:
236-
"""Return list of headers for a given JumpStart model id and semantic version.
258+
) -> JumpStartModelHeader:
259+
"""Return header for a given JumpStart model id and semantic version.
260+
261+
Args:
262+
model_id (str): model id for which to get a header.
263+
semantic_version_str (Optional[str]): The semantic version for which to get a
264+
header. If None, the highest compatible version is returned.
265+
"""
266+
267+
return self._get_header_impl(model_id, 0, semantic_version_str)
268+
269+
def _get_header_impl(
270+
self, model_id: str, attempt: int, semantic_version_str: Optional[str] = None
271+
) -> JumpStartModelHeader:
272+
"""Lower-level function to return header.
273+
274+
Allows a single retry, after clearing the cache, if the model is not found in the manifest.
237275
238276
Args:
239277
model_id (str): model id for which to get a header.
278+
attempt (int): attempt number at retrieving a header.
240279
semantic_version_str (Optional[str]): The semantic version for which to get a
241280
header. If None, the highest compatible version is returned.
242281
"""
@@ -248,17 +287,12 @@ def get_header(
248287
JumpStartCachedS3ContentKey(JumpStartS3FileType.MANIFEST, self._manifest_file_s3_key)
249288
).formatted_file_content
250289
try:
251-
header = manifest[versioned_model_id]
252-
if self._has_retried_cache_refresh:
253-
self._has_retried_cache_refresh = False
254-
return header
290+
return manifest[versioned_model_id]
255291
except KeyError:
256-
if self._has_retried_cache_refresh:
257-
self._has_retried_cache_refresh = False
292+
if attempt > 0:
258293
raise
259294
self.clear()
260-
self._has_retried_cache_refresh = True
261-
return self.get_header(model_id, semantic_version)
295+
return self._get_header_impl(model_id, attempt + 1, semantic_version_str)
262296

263297
def get_specs(
264298
self, model_id: str, semantic_version_str: Optional[str] = None
@@ -278,7 +312,6 @@ def get_specs(
278312
).formatted_file_content
279313

280314
def clear(self) -> None:
281-
"""Clears the model id/version and s3 cache and resets ``_has_retried_cache_refresh``."""
315+
"""Clears the model id/version and s3 cache."""
282316
self._s3_cache.clear()
283317
self._model_id_semantic_version_manifest_key_cache.clear()
284-
self._has_retried_cache_refresh = False

src/sagemaker/jumpstart/constants.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,17 @@
1313
"""This module stores constants related to SageMaker JumpStart."""
1414
from __future__ import absolute_import
1515
from typing import Set
16+
import boto3
1617
from sagemaker.jumpstart.types import JumpStartLaunchedRegionInfo
1718

1819

19-
LAUNCHED_REGIONS: Set[JumpStartLaunchedRegionInfo] = set()
20+
JUMPSTART_LAUNCHED_REGIONS: Set[JumpStartLaunchedRegionInfo] = set()
2021

21-
REGION_NAME_TO_LAUNCHED_REGION_DICT = {region.region_name: region for region in LAUNCHED_REGIONS}
22-
REGION_NAME_SET = {region.region_name for region in LAUNCHED_REGIONS}
22+
JUMPSTART_REGION_NAME_TO_LAUNCHED_REGION_DICT = {
23+
region.region_name: region for region in JUMPSTART_LAUNCHED_REGIONS
24+
}
25+
JUMPSTART_REGION_NAME_SET = {region.region_name for region in JUMPSTART_LAUNCHED_REGIONS}
26+
27+
JUMPSTART_DEFAULT_REGION_NAME = boto3.session.Session().region_name
28+
29+
JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY = "models_manifest.json"

src/sagemaker/jumpstart/parameters.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""This module stores parameters related to SageMaker JumpStart."""
14+
from __future__ import absolute_import
15+
import datetime
16+
17+
JUMPSTART_DEFAULT_MAX_S3_CACHE_ITEMS = 20
18+
JUMPSTART_DEFAULT_MAX_SEMANTIC_VERSION_CACHE_ITEMS = 20
19+
JUMPSTART_DEFAULT_S3_CACHE_EXPIRATION_HORIZON = datetime.timedelta(hours=6)
20+
JUMPSTART_DEFAULT_SEMANTIC_VERSION_CACHE_EXPIRATION_HORIZON = datetime.timedelta(hours=6)

0 commit comments

Comments
 (0)