10
10
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11
11
# ANY KIND, either express or implied. See the License for the specific
12
12
# language governing permissions and limitations under the License.
13
+ """This module defines the JumpStartModelsCache class."""
14
+ from __future__ import absolute_import
13
15
import datetime
14
16
from typing import List , Optional
17
+ import json
18
+ import boto3
19
+ import semantic_version
15
20
from sagemaker .jumpstart .types import (
16
21
JumpStartCachedS3ContentKey ,
17
22
JumpStartCachedS3ContentValue ,
18
23
JumpStartModelHeader ,
19
24
JumpStartModelSpecs ,
20
- JumpStartModelSpecs ,
21
25
JumpStartS3FileType ,
22
26
JumpStartVersionedModelId ,
23
27
)
24
28
from sagemaker .jumpstart import utils
25
29
from sagemaker .utilities .cache import LRUCache
26
- import boto3
27
- import json
28
- import semantic_version
29
-
30
30
31
31
DEFAULT_REGION_NAME = boto3 .session .Session ().region_name
32
32
41
41
42
42
class JumpStartModelsCache :
43
43
"""Class that implements a cache for JumpStart models manifests and specs.
44
+
44
45
The manifest and specs associated with JumpStart models provide the information necessary
45
46
for launching JumpStart models from the SageMaker SDK.
46
47
"""
@@ -62,15 +63,16 @@ def __init__(
62
63
Args:
63
64
region (Optional[str]): AWS region to associate with cache. Default: region associated
64
65
with botocore session.
65
- max_s3_cache_items (Optional[int]): Maximum number of files to store in s3 cache. Default: 20.
66
- s3_cache_expiration_time (Optional[datetime.timedelta]): Maximum time to hold items in s3
67
- cache before invalidation. Default: 6 hours.
66
+ max_s3_cache_items (Optional[int]): Maximum number of files to store in s3 cache.
67
+ Default: 20.
68
+ s3_cache_expiration_time (Optional[datetime.timedelta]): Maximum time to hold items in
69
+ s3 cache before invalidation. Default: 6 hours.
68
70
max_semantic_version_cache_items (Optional[int]): Maximum number of files to store in
69
71
semantic version cache. Default: 20.
70
- semantic_version_cache_expiration_time (Optional[datetime.timedelta]): Maximum time to hold
71
- items in semantic version cache before invalidation. Default: 6 hours.
72
- bucket (Optional[str]): S3 bucket to associate with cache. Default: JumpStart-hosted content
73
- bucket for region.
72
+ semantic_version_cache_expiration_time (Optional[datetime.timedelta]): Maximum time to
73
+ hold items in semantic version cache before invalidation. Default: 6 hours.
74
+ bucket (Optional[str]): S3 bucket to associate with cache. Default: JumpStart-hosted
75
+ content bucket for region.
74
76
"""
75
77
76
78
self ._region = region
@@ -120,15 +122,16 @@ def get_bucket(self) -> None:
120
122
return self ._bucket
121
123
122
124
def _get_manifest_key_from_model_id_semantic_version (
123
- self , key : JumpStartVersionedModelId , value : Optional [JumpStartVersionedModelId ]
125
+ self ,
126
+ key : JumpStartVersionedModelId ,
127
+ value : Optional [JumpStartVersionedModelId ], # pylint: disable=W0613
124
128
) -> JumpStartVersionedModelId :
125
- """Return model id and version in manifest that matches semantic version/id
126
- from customer request.
129
+ """Return model id and version in manifest that matches semantic version/id.
127
130
128
131
Args:
129
132
key (JumpStartVersionedModelId): Key for which to fetch versioned model id.
130
- value (Optional[JumpStartVersionedModelId]): Unused variable for current value of old cached
131
- model id/version.
133
+ value (Optional[JumpStartVersionedModelId]): Unused variable for current value of
134
+ old cached model id/version.
132
135
133
136
Raises:
134
137
KeyError: If the semantic version is not found in the manifest.
@@ -158,42 +161,42 @@ def _get_manifest_key_from_model_id_semantic_version(
158
161
sm_compatible_model_version = spec .select (versions_compatible_with_sagemaker )
159
162
if sm_compatible_model_version is not None :
160
163
return JumpStartVersionedModelId (model_id , str (sm_compatible_model_version ))
161
- else :
162
- versions_incompatible_with_sagemaker = [
163
- semantic_version .Version (header .version )
164
+
165
+ versions_incompatible_with_sagemaker = [
166
+ semantic_version .Version (header .version )
167
+ for _ , header in manifest .items ()
168
+ if header .model_id == model_id
169
+ ]
170
+ sm_incompatible_model_version = spec .select (versions_incompatible_with_sagemaker )
171
+ if sm_incompatible_model_version is not None :
172
+ model_version_to_use_incompatible_with_sagemaker = str (sm_incompatible_model_version )
173
+ sm_version_to_use = [
174
+ header .min_version
164
175
for _ , header in manifest .items ()
165
176
if header .model_id == model_id
177
+ and header .version == model_version_to_use_incompatible_with_sagemaker
166
178
]
167
- sm_incompatible_model_version = spec .select (versions_incompatible_with_sagemaker )
168
- if sm_incompatible_model_version is not None :
169
- model_version_to_use_incompatible_with_sagemaker = str (
170
- sm_incompatible_model_version
171
- )
172
- sm_version_to_use = [
173
- header .min_version
174
- for _ , header in manifest .items ()
175
- if header .model_id == model_id
176
- and header .version == model_version_to_use_incompatible_with_sagemaker
177
- ]
178
- assert len (sm_version_to_use ) == 1
179
- sm_version_to_use = sm_version_to_use [0 ]
180
-
181
- error_msg = (
182
- f"Unable to find model manifest for { model_id } with version { version } compatible with your SageMaker version ({ sm_version } ). "
183
- f"Consider upgrading your SageMaker library to at least version { sm_version_to_use } so you can use version "
184
- f"{ model_version_to_use_incompatible_with_sagemaker } of { model_id } ."
185
- )
186
- raise KeyError (error_msg )
187
- else :
188
- error_msg = f"Unable to find model manifest for { model_id } with version { version } "
189
- raise KeyError (error_msg )
179
+ assert len (sm_version_to_use ) == 1
180
+ sm_version_to_use = sm_version_to_use [0 ]
181
+
182
+ error_msg = (
183
+ f"Unable to find model manifest for { model_id } with version { version } "
184
+ f"compatible with your SageMaker version ({ sm_version } ). "
185
+ f"Consider upgrading your SageMaker library to at least version "
186
+ f"{ sm_version_to_use } so you can use version "
187
+ f"{ model_version_to_use_incompatible_with_sagemaker } of { model_id } ."
188
+ )
189
+ raise KeyError (error_msg )
190
+ error_msg = f"Unable to find model manifest for { model_id } with version { version } "
191
+ raise KeyError (error_msg )
190
192
191
193
def _get_file_from_s3 (
192
194
self ,
193
195
key : JumpStartCachedS3ContentKey ,
194
196
value : Optional [JumpStartCachedS3ContentValue ],
195
197
) -> JumpStartCachedS3ContentValue :
196
198
"""Return s3 content given a file type and s3_key in ``JumpStartCachedS3ContentKey``.
199
+
197
200
If a manifest file is being fetched, we only download the object if the md5 hash in
198
201
``head_object`` does not match the current md5 hash for the stored value. This prevents
199
202
unnecessarily downloading the full manifest when it hasn't changed.
@@ -228,18 +231,18 @@ def _get_file_from_s3(
228
231
raise RuntimeError (f"Bad value for key: { key } " )
229
232
230
233
def get_header (
231
- self , model_id : str , semantic_version : Optional [str ] = None
234
+ self , model_id : str , semantic_version_str : Optional [str ] = None
232
235
) -> List [JumpStartModelHeader ]:
233
236
"""Return list of headers for a given JumpStart model id and semantic version.
234
237
235
238
Args:
236
239
model_id (str): model id for which to get a header.
237
- semantic_version (Optional[str]): The semantic version for which to get a header.
238
- If None, the highest compatible version is returned.
240
+ semantic_version_str (Optional[str]): The semantic version for which to get a
241
+ header. If None, the highest compatible version is returned.
239
242
"""
240
243
241
244
versioned_model_id = self ._model_id_semantic_version_manifest_key_cache .get (
242
- JumpStartVersionedModelId (model_id , semantic_version )
245
+ JumpStartVersionedModelId (model_id , semantic_version_str )
243
246
)
244
247
manifest = self ._s3_cache .get (
245
248
JumpStartCachedS3ContentKey (JumpStartS3FileType .MANIFEST , self ._manifest_file_s3_key )
@@ -258,16 +261,17 @@ def get_header(
258
261
return self .get_header (model_id , semantic_version )
259
262
260
263
def get_specs (
261
- self , model_id : str , semantic_version : Optional [str ] = None
264
+ self , model_id : str , semantic_version_str : Optional [str ] = None
262
265
) -> JumpStartModelSpecs :
263
266
"""Return specs for a given JumpStart model id and semantic version.
264
267
265
268
Args:
266
269
model_id (str): model id for which to get specs.
267
- semantic_version (Optional[str]): The semantic version for which to get specs.
268
- If None, the highest compatible version is returned.
270
+ semantic_version_str (Optional[str]): The semantic version for which to get
271
+ specs. If None, the highest compatible version is returned.
269
272
"""
270
- header = self .get_header (model_id , semantic_version )
273
+
274
+ header = self .get_header (model_id , semantic_version_str )
271
275
spec_key = header .spec_key
272
276
return self ._s3_cache .get (
273
277
JumpStartCachedS3ContentKey (JumpStartS3FileType .SPECS , spec_key )
0 commit comments