Skip to content

Commit 7511bb0

Browse files
committed
change: add test for cache memory usage, add get_manifest function for jumpstart cache
1 parent d7ff8f8 commit 7511bb0

File tree

4 files changed

+84
-42
lines changed

4 files changed

+84
-42
lines changed

src/sagemaker/jumpstart/cache.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
"""This module defines the JumpStartModelsCache class."""
1414
from __future__ import absolute_import
1515
import datetime
16-
from typing import Optional
16+
from typing import List, Optional
1717
import json
1818
import boto3
1919
import botocore
@@ -253,6 +253,13 @@ def _get_file_from_s3(
253253
f"Bad value for key '{key}': must be in {[JumpStartS3FileType.MANIFEST, JumpStartS3FileType.SPECS]}"
254254
)
255255

256+
def get_manifest(self) -> List[JumpStartModelHeader]:
257+
"""Return entire JumpStart models manifest."""
258+
259+
return self._s3_cache.get(
260+
JumpStartCachedS3ContentKey(JumpStartS3FileType.MANIFEST, self._manifest_file_s3_key)
261+
).formatted_file_content.values()
262+
256263
def get_header(
257264
self, model_id: str, semantic_version_str: Optional[str] = None
258265
) -> JumpStartModelHeader:

src/sagemaker/jumpstart/types.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,12 +32,13 @@ def __eq__(self, other: Any) -> bool:
3232

3333
if not isinstance(other, type(self)):
3434
return False
35+
if getattr(other, "__slots__", None) is None:
36+
return False
37+
if self.__slots__ != other.__slots__:
38+
return False
3539
for attribute in self.__slots__:
3640
if getattr(self, attribute) != getattr(other, attribute):
3741
return False
38-
for attribute in other.__slots__:
39-
if getattr(self, attribute) != getattr(other, attribute):
40-
return False
4142
return True
4243

4344
def __hash__(self) -> int:

tests/unit/sagemaker/jumpstart/test_cache.py

Lines changed: 49 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,44 @@
6060
},
6161
}
6262

63+
BASE_MANIFEST = [
64+
{
65+
"model_id": "tensorflow-ic-imagenet-inception-v3-classification-4",
66+
"version": "1.0.0",
67+
"min_version": "2.49.0",
68+
"spec_key": "community_models_specs/tensorflow-ic-imagenet"
69+
"-inception-v3-classification-4/specs_v1.0.0.json",
70+
},
71+
{
72+
"model_id": "tensorflow-ic-imagenet-inception-v3-classification-4",
73+
"version": "2.0.0",
74+
"min_version": "2.49.0",
75+
"spec_key": "community_models_specs/tensorflow-ic-imagenet"
76+
"-inception-v3-classification-4/specs_v2.0.0.json",
77+
},
78+
{
79+
"model_id": "pytorch-ic-imagenet-inception-v3-classification-4",
80+
"version": "1.0.0",
81+
"min_version": "2.49.0",
82+
"spec_key": "community_models_specs/pytorch-ic-"
83+
"imagenet-inception-v3-classification-4/specs_v1.0.0.json",
84+
},
85+
{
86+
"model_id": "pytorch-ic-imagenet-inception-v3-classification-4",
87+
"version": "2.0.0",
88+
"min_version": "2.49.0",
89+
"spec_key": "community_models_specs/pytorch-ic-imagenet-"
90+
"inception-v3-classification-4/specs_v2.0.0.json",
91+
},
92+
{
93+
"model_id": "tensorflow-ic-imagenet-inception-v3-classification-4",
94+
"version": "3.0.0",
95+
"min_version": "4.49.0",
96+
"spec_key": "community_models_specs/tensorflow-ic-"
97+
"imagenet-inception-v3-classification-4/specs_v3.0.0.json",
98+
},
99+
]
100+
63101

64102
def get_spec_from_base_spec(model_id: str, version: str) -> JumpStartModelSpecs:
65103
spec = copy.deepcopy(BASE_SPEC)
@@ -77,45 +115,9 @@ def patched_get_file_from_s3(
77115

78116
filetype, s3_key = key.file_type, key.s3_key
79117
if filetype == JumpStartS3FileType.MANIFEST:
80-
manifest = [
81-
{
82-
"model_id": "tensorflow-ic-imagenet-inception-v3-classification-4",
83-
"version": "1.0.0",
84-
"min_version": "2.49.0",
85-
"spec_key": "community_models_specs/tensorflow-ic-imagenet"
86-
"-inception-v3-classification-4/specs_v1.0.0.json",
87-
},
88-
{
89-
"model_id": "tensorflow-ic-imagenet-inception-v3-classification-4",
90-
"version": "2.0.0",
91-
"min_version": "2.49.0",
92-
"spec_key": "community_models_specs/tensorflow-ic-imagenet"
93-
"-inception-v3-classification-4/specs_v2.0.0.json",
94-
},
95-
{
96-
"model_id": "pytorch-ic-imagenet-inception-v3-classification-4",
97-
"version": "1.0.0",
98-
"min_version": "2.49.0",
99-
"spec_key": "community_models_specs/pytorch-ic-"
100-
"imagenet-inception-v3-classification-4/specs_v1.0.0.json",
101-
},
102-
{
103-
"model_id": "pytorch-ic-imagenet-inception-v3-classification-4",
104-
"version": "2.0.0",
105-
"min_version": "2.49.0",
106-
"spec_key": "community_models_specs/pytorch-ic-imagenet-"
107-
"inception-v3-classification-4/specs_v2.0.0.json",
108-
},
109-
{
110-
"model_id": "tensorflow-ic-imagenet-inception-v3-classification-4",
111-
"version": "3.0.0",
112-
"min_version": "4.49.0",
113-
"spec_key": "community_models_specs/tensorflow-ic-"
114-
"imagenet-inception-v3-classification-4/specs_v3.0.0.json",
115-
},
116-
]
118+
117119
return JumpStartCachedS3ContentValue(
118-
formatted_file_content=get_formatted_manifest(manifest)
120+
formatted_file_content=get_formatted_manifest(BASE_MANIFEST)
119121
)
120122

121123
if filetype == JumpStartS3FileType.SPECS:
@@ -579,6 +581,15 @@ def test_jumpstart_cache_handles_bad_semantic_version_manifest_key_cache():
579581
cache.clear.assert_called_once()
580582

581583

584+
@patch.object(JumpStartModelsCache, "_get_file_from_s3", patched_get_file_from_s3)
585+
@patch("sagemaker.jumpstart.utils.get_sagemaker_version", lambda: "2.68.3")
586+
def test_jumpstart_get_full_manifest():
587+
cache = JumpStartModelsCache(s3_bucket_name="some_bucket")
588+
raw_manifest = [header.to_json() for header in cache.get_manifest()]
589+
590+
raw_manifest == BASE_MANIFEST
591+
592+
582593
@patch.object(JumpStartModelsCache, "_get_file_from_s3", patched_get_file_from_s3)
583594
@patch("sagemaker.jumpstart.utils.get_sagemaker_version", lambda: "2.68.3")
584595
def test_jumpstart_cache_get_specs():

tests/unit/sagemaker/utilities/test_cache.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
from typing import Optional, Union
1515
from mock.mock import MagicMock, patch
1616
import pytest
17+
import pickle
18+
1719

1820
from sagemaker.utilities import cache
1921
import datetime
@@ -192,3 +194,24 @@ def test_cache_clear_and_contains():
192194
assert len(my_cache) == 0
193195
with pytest.raises(KeyError):
194196
my_cache.get(1, False)
197+
198+
199+
def test_cache_memory_usage():
200+
my_cache = cache.LRUCache[int, Union[int, str]](
201+
max_cache_items=10,
202+
expiration_horizon=datetime.timedelta(hours=1),
203+
retrieval_function=retrieval_function,
204+
)
205+
cache_size_bytes = []
206+
cache_size_bytes.append(len(pickle.dumps(my_cache)))
207+
for i in range(50):
208+
my_cache.put(i)
209+
cache_size_bytes.append(len(pickle.dumps(my_cache)))
210+
211+
max_cache_items_iter_cache_size = cache_size_bytes[10]
212+
past_capacity_iter_cache_size = cache_size_bytes[50]
213+
percent_difference_cache_size = (
214+
abs(past_capacity_iter_cache_size - max_cache_items_iter_cache_size)
215+
/ max_cache_items_iter_cache_size
216+
) * 100.0
217+
assert percent_difference_cache_size < 1

0 commit comments

Comments
 (0)