13
13
"""This module defines the JumpStartModelsCache class."""
14
14
from __future__ import absolute_import
15
15
import datetime
16
- from typing import List , Optional
16
+ from typing import Optional
17
17
import json
18
18
import boto3
19
+ import botocore
19
20
import semantic_version
21
+ from sagemaker .jumpstart .constants import (
22
+ JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY ,
23
+ JUMPSTART_DEFAULT_REGION_NAME ,
24
+ )
25
+ from sagemaker .jumpstart .parameters import (
26
+ JUMPSTART_DEFAULT_MAX_S3_CACHE_ITEMS ,
27
+ JUMPSTART_DEFAULT_MAX_SEMANTIC_VERSION_CACHE_ITEMS ,
28
+ JUMPSTART_DEFAULT_S3_CACHE_EXPIRATION_HORIZON ,
29
+ JUMPSTART_DEFAULT_SEMANTIC_VERSION_CACHE_EXPIRATION_HORIZON ,
30
+ )
20
31
from sagemaker .jumpstart .types import (
21
32
JumpStartCachedS3ContentKey ,
22
33
JumpStartCachedS3ContentValue ,
28
39
from sagemaker .jumpstart import utils
29
40
from sagemaker .utilities .cache import LRUCache
30
41
31
- DEFAULT_REGION_NAME = boto3 .session .Session ().region_name
32
-
33
- DEFAULT_MAX_S3_CACHE_ITEMS = 20
34
- DEFAULT_S3_CACHE_EXPIRATION_TIME = datetime .timedelta (hours = 6 )
35
-
36
- DEFAULT_MAX_SEMANTIC_VERSION_CACHE_ITEMS = 20
37
- DEFAULT_SEMANTIC_VERSION_CACHE_EXPIRATION_TIME = datetime .timedelta (hours = 6 )
38
-
39
- DEFAULT_MANIFEST_FILE_S3_KEY = "models_manifest.json"
40
-
41
42
42
43
class JumpStartModelsCache :
43
44
"""Class that implements a cache for JumpStart models manifests and specs.
@@ -48,78 +49,95 @@ class JumpStartModelsCache:
48
49
49
50
def __init__(
    self,
    region: Optional[str] = JUMPSTART_DEFAULT_REGION_NAME,
    max_s3_cache_items: Optional[int] = JUMPSTART_DEFAULT_MAX_S3_CACHE_ITEMS,
    s3_cache_expiration_horizon: Optional[
        datetime.timedelta
    ] = JUMPSTART_DEFAULT_S3_CACHE_EXPIRATION_HORIZON,
    max_semantic_version_cache_items: Optional[
        int
    ] = JUMPSTART_DEFAULT_MAX_SEMANTIC_VERSION_CACHE_ITEMS,
    semantic_version_cache_expiration_horizon: Optional[
        datetime.timedelta
    ] = JUMPSTART_DEFAULT_SEMANTIC_VERSION_CACHE_EXPIRATION_HORIZON,
    manifest_file_s3_key: Optional[str] = JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY,
    s3_bucket_name: Optional[str] = None,
    s3_client_config: Optional[botocore.config.Config] = None,
) -> None:
    """Initialize a ``JumpStartModelsCache`` instance.

    Args:
        region (Optional[str]): AWS region to associate with cache.
            Default: ``JUMPSTART_DEFAULT_REGION_NAME``.
        max_s3_cache_items (Optional[int]): Maximum number of items to store in
            the s3 cache. Default: ``JUMPSTART_DEFAULT_MAX_S3_CACHE_ITEMS``.
        s3_cache_expiration_horizon (Optional[datetime.timedelta]): Maximum time
            to hold items in the s3 cache before invalidation.
            Default: ``JUMPSTART_DEFAULT_S3_CACHE_EXPIRATION_HORIZON``.
        max_semantic_version_cache_items (Optional[int]): Maximum number of items
            to store in the semantic version cache.
            Default: ``JUMPSTART_DEFAULT_MAX_SEMANTIC_VERSION_CACHE_ITEMS``.
        semantic_version_cache_expiration_horizon (Optional[datetime.timedelta]):
            Maximum time to hold items in the semantic version cache before
            invalidation.
            Default: ``JUMPSTART_DEFAULT_SEMANTIC_VERSION_CACHE_EXPIRATION_HORIZON``.
        manifest_file_s3_key (Optional[str]): S3 key of the models manifest file.
            Default: ``JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY``.
        s3_bucket_name (Optional[str]): S3 bucket to associate with cache.
            Default: JumpStart-hosted content bucket for region.
        s3_client_config (Optional[botocore.config.Config]): s3 client config to
            use for cache. Default: None (no config).
    """

    self._region = region

    # Cache of raw S3 content (manifest and specs files).
    self._s3_cache = LRUCache[JumpStartCachedS3ContentKey, JumpStartCachedS3ContentValue](
        max_cache_items=max_s3_cache_items,
        expiration_horizon=s3_cache_expiration_horizon,
        retrieval_function=self._get_file_from_s3,
    )

    # Cache mapping a requested (model id, semantic version) to the concrete
    # (model id, version) entry found in the manifest.
    self._model_id_semantic_version_manifest_key_cache = LRUCache[
        JumpStartVersionedModelId, JumpStartVersionedModelId
    ](
        max_cache_items=max_semantic_version_cache_items,
        expiration_horizon=semantic_version_cache_expiration_horizon,
        retrieval_function=self._get_manifest_key_from_model_id_semantic_version,
    )

    self._manifest_file_s3_key = manifest_file_s3_key

    # Only look up the regional content bucket when the caller gave none.
    if s3_bucket_name is None:
        s3_bucket_name = utils.get_jumpstart_content_bucket(self._region)
    self.s3_bucket_name = s3_bucket_name

    # Forward ``config`` only when a config object was actually supplied.
    client_kwargs = {"region_name": self._region}
    if s3_client_config:
        client_kwargs["config"] = s3_client_config
    self._s3_client = boto3.client("s3", **client_kwargs)
96
111
97
112
def set_region(self, region: str) -> None:
    """Set region for cache.

    The cache is cleared only when ``region`` differs from the current
    region; setting the same region again is a no-op.

    Args:
        region (str): AWS region to associate with cache.
    """
    if region == self._region:
        return
    self._region = region
    self.clear()
101
117
102
118
def get_region(self) -> str:
    """Return the AWS region currently associated with the cache."""
    return self._region
105
121
106
122
def set_manifest_file_s3_key(self, key: str) -> None:
    """Set manifest file s3 key.

    The cache is cleared only when ``key`` differs from the current key;
    setting the same key again is a no-op.

    Args:
        key (str): S3 key of the manifest file.
    """
    if key == self._manifest_file_s3_key:
        return
    self._manifest_file_s3_key = key
    self.clear()
110
127
111
128
def get_manifest_file_s3_key(self) -> str:
    """Return manifest file s3 key for cache.

    Returns:
        str: The S3 key currently stored for the manifest file.
    """
    return self._manifest_file_s3_key
114
131
115
def set_s3_bucket_name(self, s3_bucket_name: str) -> None:
    """Set s3 bucket used for cache.

    The cache is cleared only when ``s3_bucket_name`` differs from the
    current bucket name; setting the same name again is a no-op.

    Args:
        s3_bucket_name (str): Name of the S3 bucket to associate with cache.
    """
    if s3_bucket_name == self.s3_bucket_name:
        return
    self.s3_bucket_name = s3_bucket_name
    self.clear()
119
137
120
138
def get_bucket(self) -> str:
    """Return bucket used for cache.

    Returns:
        str: The name of the S3 bucket currently associated with the cache.
    """
    return self.s3_bucket_name
123
141
124
142
def _get_manifest_key_from_model_id_semantic_version (
125
143
self ,
@@ -128,13 +146,18 @@ def _get_manifest_key_from_model_id_semantic_version(
128
146
) -> JumpStartVersionedModelId :
129
147
"""Return model id and version in manifest that matches semantic version/id.
130
148
149
+ Uses ``semantic_version`` to perform version comparison. The highest model version
150
+ matching the semantic version is used, which is compatible with the SageMaker
151
+ version.
152
+
131
153
Args:
132
154
key (JumpStartVersionedModelId): Key for which to fetch versioned model id.
133
155
value (Optional[JumpStartVersionedModelId]): Unused variable for current value of
134
156
old cached model id/version.
135
157
136
158
Raises:
137
- KeyError: If the semantic version is not found in the manifest.
159
+ KeyError: If the semantic version is not found in the manifest, or is found but
160
+ the SageMaker version needs to be upgraded in order for the model to be used.
138
161
"""
139
162
140
163
model_id , version = key .model_id , key .version
@@ -147,7 +170,7 @@ def _get_manifest_key_from_model_id_semantic_version(
147
170
148
171
versions_compatible_with_sagemaker = [
149
172
semantic_version .Version (header .version )
150
- for _ , header in manifest .items ()
173
+ for header in manifest .values ()
151
174
if header .model_id == model_id
152
175
and semantic_version .Version (header .min_version ) <= semantic_version .Version (sm_version )
153
176
]
@@ -164,19 +187,19 @@ def _get_manifest_key_from_model_id_semantic_version(
164
187
165
188
versions_incompatible_with_sagemaker = [
166
189
semantic_version .Version (header .version )
167
- for _ , header in manifest .items ()
190
+ for header in manifest .values ()
168
191
if header .model_id == model_id
169
192
]
170
193
sm_incompatible_model_version = spec .select (versions_incompatible_with_sagemaker )
171
194
if sm_incompatible_model_version is not None :
172
195
model_version_to_use_incompatible_with_sagemaker = str (sm_incompatible_model_version )
173
196
sm_version_to_use = [
174
197
header .min_version
175
- for _ , header in manifest .items ()
198
+ for header in manifest .values ()
176
199
if header .model_id == model_id
177
200
and header .version == model_version_to_use_incompatible_with_sagemaker
178
201
]
179
- assert len (sm_version_to_use ) == 1
202
+ assert len (sm_version_to_use ) == 1 # ``manifest`` dict should already enforce this
180
203
sm_version_to_use = sm_version_to_use [0 ]
181
204
182
205
error_msg = (
@@ -187,7 +210,7 @@ def _get_manifest_key_from_model_id_semantic_version(
187
210
f"{ model_version_to_use_incompatible_with_sagemaker } of { model_id } ."
188
211
)
189
212
raise KeyError (error_msg )
190
- error_msg = f"Unable to find model manifest for { model_id } with version { version } "
213
+ error_msg = f"Unable to find model manifest for { model_id } with version { version } . "
191
214
raise KeyError (error_msg )
192
215
193
216
def _get_file_from_s3 (
@@ -210,33 +233,49 @@ def _get_file_from_s3(
210
233
211
234
file_type , s3_key = key .file_type , key .s3_key
212
235
213
- s3_client = boto3 .client ("s3" , region_name = self ._region )
214
-
215
236
if file_type == JumpStartS3FileType .MANIFEST :
216
- etag = s3_client . head_object (Bucket = self ._bucket , Key = s3_key )["ETag" ]
237
+ etag = self . _s3_client . head_object (Bucket = self .s3_bucket_name , Key = s3_key )["ETag" ]
217
238
if value is not None and etag == value .md5_hash :
218
239
return value
219
- response = s3_client . get_object (Bucket = self ._bucket , Key = s3_key )
240
+ response = self . _s3_client . get_object (Bucket = self .s3_bucket_name , Key = s3_key )
220
241
formatted_body = json .loads (response ["Body" ].read ().decode ("utf-8" ))
221
242
return JumpStartCachedS3ContentValue (
222
243
formatted_file_content = utils .get_formatted_manifest (formatted_body ),
223
244
md5_hash = etag ,
224
245
)
225
246
if file_type == JumpStartS3FileType .SPECS :
226
- response = s3_client . get_object (Bucket = self ._bucket , Key = s3_key )
247
+ response = self . _s3_client . get_object (Bucket = self .s3_bucket_name , Key = s3_key )
227
248
formatted_body = json .loads (response ["Body" ].read ().decode ("utf-8" ))
228
249
return JumpStartCachedS3ContentValue (
229
250
formatted_file_content = JumpStartModelSpecs (formatted_body )
230
251
)
231
- raise RuntimeError (f"Bad value for key: { key } " )
252
+ raise ValueError (
253
+ f"Bad value for key '{ key } ': must be in { [JumpStartS3FileType .MANIFEST , JumpStartS3FileType .SPECS ]} "
254
+ )
232
255
233
256
def get_header(
    self, model_id: str, semantic_version_str: Optional[str] = None
) -> JumpStartModelHeader:
    """Return header for a given JumpStart model id and semantic version.

    Delegates to ``_get_header_impl``, starting at attempt number 0.

    Args:
        model_id (str): model id for which to get a header.
        semantic_version_str (Optional[str]): The semantic version for which to get a
            header. If None, the highest compatible version is returned.
    """

    return self._get_header_impl(
        model_id, attempt=0, semantic_version_str=semantic_version_str
    )
268
+
269
+ def _get_header_impl (
270
+ self , model_id : str , attempt : int , semantic_version_str : Optional [str ] = None
271
+ ) -> JumpStartModelHeader :
272
+ """Lower-level function to return header.
273
+
274
+ Allows a single retry if the cache is old.
237
275
238
276
Args:
239
277
model_id (str): model id for which to get a header.
278
+ attempt (int): attempt number at retrieving a header.
240
279
semantic_version_str (Optional[str]): The semantic version for which to get a
241
280
header. If None, the highest compatible version is returned.
242
281
"""
@@ -248,17 +287,12 @@ def get_header(
248
287
JumpStartCachedS3ContentKey (JumpStartS3FileType .MANIFEST , self ._manifest_file_s3_key )
249
288
).formatted_file_content
250
289
try :
251
- header = manifest [versioned_model_id ]
252
- if self ._has_retried_cache_refresh :
253
- self ._has_retried_cache_refresh = False
254
- return header
290
+ return manifest [versioned_model_id ]
255
291
except KeyError :
256
- if self ._has_retried_cache_refresh :
257
- self ._has_retried_cache_refresh = False
292
+ if attempt > 0 :
258
293
raise
259
294
self .clear ()
260
- self ._has_retried_cache_refresh = True
261
- return self .get_header (model_id , semantic_version )
295
+ return self ._get_header_impl (model_id , attempt + 1 , semantic_version_str )
262
296
263
297
def get_specs (
264
298
self , model_id : str , semantic_version_str : Optional [str ] = None
@@ -278,7 +312,6 @@ def get_specs(
278
312
).formatted_file_content
279
313
280
314
def clear(self) -> None:
    """Clears the s3 cache and the model id/version cache, in that order."""
    for cache in (self._s3_cache, self._model_id_semantic_version_manifest_key_cache):
        cache.clear()
0 commit comments