Skip to content

Commit 04143a6

Browse files
committed
change: remove cache memory test, improve jumpstart cache design/tests
1 parent 7511bb0 commit 04143a6

File tree

7 files changed

+126
-73
lines changed

7 files changed

+126
-73
lines changed

src/sagemaker/jumpstart/cache.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -234,11 +234,13 @@ def _get_file_from_s3(
234234
file_type, s3_key = key.file_type, key.s3_key
235235

236236
if file_type == JumpStartS3FileType.MANIFEST:
237-
etag = self._s3_client.head_object(Bucket=self.s3_bucket_name, Key=s3_key)["ETag"]
238-
if value is not None and etag == value.md5_hash:
239-
return value
237+
if value is not None:
238+
etag = self._s3_client.head_object(Bucket=self.s3_bucket_name, Key=s3_key)["ETag"]
239+
if etag == value.md5_hash:
240+
return value
240241
response = self._s3_client.get_object(Bucket=self.s3_bucket_name, Key=s3_key)
241242
formatted_body = json.loads(response["Body"].read().decode("utf-8"))
243+
etag = response["ETag"]
242244
return JumpStartCachedS3ContentValue(
243245
formatted_file_content=utils.get_formatted_manifest(formatted_body),
244246
md5_hash=etag,
@@ -271,10 +273,10 @@ def get_header(
271273
header. If None, the highest compatible version is returned.
272274
"""
273275

274-
return self._get_header_impl(model_id, 0, semantic_version_str)
276+
return self._get_header_impl(model_id, semantic_version_str=semantic_version_str)
275277

276278
def _get_header_impl(
277-
self, model_id: str, attempt: int, semantic_version_str: Optional[str] = None
279+
self, model_id: str, attempt: Optional[int] = 0, semantic_version_str: Optional[str] = None
278280
) -> JumpStartModelHeader:
279281
"""Lower-level function to return header.
280282

src/sagemaker/jumpstart/types.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -144,12 +144,12 @@ class JumpStartModelSpecs(JumpStartDataHolderType):
144144
"min_sdk_version",
145145
"incremental_training_supported",
146146
"hosting_ecr_specs",
147-
"hosting_artifact_uri",
148-
"hosting_script_uri",
147+
"hosting_artifact_key",
148+
"hosting_script_key",
149149
"training_supported",
150150
"training_ecr_specs",
151-
"training_artifact_uri",
152-
"training_script_uri",
151+
"training_artifact_key",
152+
"training_script_key",
153153
"hyperparameters",
154154
]
155155

@@ -164,20 +164,20 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
164164
self.min_sdk_version: str = json_obj["min_sdk_version"]
165165
self.incremental_training_supported: bool = bool(json_obj["incremental_training_supported"])
166166
self.hosting_ecr_specs: JumpStartECRSpecs = JumpStartECRSpecs(json_obj["hosting_ecr_specs"])
167-
self.hosting_artifact_uri: str = json_obj["hosting_artifact_uri"]
168-
self.hosting_script_uri: str = json_obj["hosting_script_uri"]
167+
self.hosting_artifact_key: str = json_obj["hosting_artifact_key"]
168+
self.hosting_script_key: str = json_obj["hosting_script_key"]
169169
self.training_supported: bool = bool(json_obj["training_supported"])
170170
if self.training_supported:
171171
self.training_ecr_specs: Optional[JumpStartECRSpecs] = JumpStartECRSpecs(
172172
json_obj["training_ecr_specs"]
173173
)
174-
self.training_artifact_uri: Optional[str] = json_obj["training_artifact_uri"]
175-
self.training_script_uri: Optional[str] = json_obj["training_script_uri"]
174+
self.training_artifact_key: Optional[str] = json_obj["training_artifact_key"]
175+
self.training_script_key: Optional[str] = json_obj["training_script_key"]
176176
self.hyperparameters: Optional[Dict[str, Any]] = json_obj["hyperparameters"]
177177
else:
178178
self.training_ecr_specs = (
179-
self.training_artifact_uri
180-
) = self.training_script_uri = self.hyperparameters = None
179+
self.training_artifact_key
180+
) = self.training_script_key = self.hyperparameters = None
181181

182182
def to_json(self) -> Dict[str, Any]:
183183
"""Returns json representation of JumpStartModelSpecs object."""

src/sagemaker/jumpstart/utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,8 @@ def get_sagemaker_version() -> str:
8282
rejected.
8383
8484
Raises:
85-
RuntimeError: If the SageMaker version is not readable.
85+
RuntimeError: If the SageMaker version is not readable. An exception is also raised if
86+
the version cannot be parsed by ``semantic_version``.
8687
"""
8788
version = sagemaker.__version__
8889
parsed_version = None

src/sagemaker/utilities/cache.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -134,12 +134,12 @@ def put(self, key: KeyType, value: Optional[ValType] = None) -> None:
134134
def _get_item(self, key: KeyType, fail_on_old_value: bool) -> ValType:
135135
"""Returns value from cache corresponding to key.
136136
137-
If ``fail_on_old_value``, a KeyError is thrown instead of a new value
137+
If ``fail_on_old_value``, a KeyError is raised instead of a new value
138138
getting fetched.
139139
140140
Args:
141141
key (KeyType): Key in cache to retrieve.
142-
fail_on_old_value (bool): True if a KeyError is thrown when the cache value
142+
fail_on_old_value (bool): True if a KeyError is raised when the cache value
143143
is old.
144144
145145
Raises:

tests/unit/sagemaker/jumpstart/test_cache.py

Lines changed: 91 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -49,10 +49,10 @@
4949
"framework_version": "1.9.0",
5050
"py_version": "py3",
5151
},
52-
"hosting_artifact_uri": "pytorch-infer/infer-pytorch-ic-mobilenet-v2.tar.gz",
53-
"training_artifact_uri": "pytorch-training/train-pytorch-ic-mobilenet-v2.tar.gz",
54-
"hosting_script_uri": "source-directory-tarballs/pytorch/inference/ic/v1.0.0/sourcedir.tar.gz",
55-
"training_script_uri": "source-directory-tarballs/pytorch/transfer_learning/ic/v1.0.0/sourcedir.tar.gz",
52+
"hosting_artifact_key": "pytorch-infer/infer-pytorch-ic-mobilenet-v2.tar.gz",
53+
"training_artifact_key": "pytorch-training/train-pytorch-ic-mobilenet-v2.tar.gz",
54+
"hosting_script_key": "source-directory-tarballs/pytorch/inference/ic/v1.0.0/sourcedir.tar.gz",
55+
"training_script_key": "source-directory-tarballs/pytorch/transfer_learning/ic/v1.0.0/sourcedir.tar.gz",
5656
"hyperparameters": {
5757
"adam-learning-rate": {"type": "float", "default": 0.05, "min": 1e-08, "max": 1},
5858
"epochs": {"type": "int", "default": 3, "min": 1, "max": 1000},
@@ -334,30 +334,100 @@ def test_jumpstart_cache_gets_cleared_when_params_are_set(mock_boto3_client):
334334

335335

336336
def test_jumpstart_cache_handles_boto3_client_errors():
337+
# Testing get_object
337338
cache = JumpStartModelsCache(s3_bucket_name="some_bucket")
338339
stubbed_s3_client = Stubber(cache._s3_client)
339-
stubbed_s3_client.add_client_error("head_object", http_status_code=404)
340340
stubbed_s3_client.add_client_error("get_object", http_status_code=404)
341341
stubbed_s3_client.activate()
342342
with pytest.raises(botocore.exceptions.ClientError):
343343
cache.get_header(model_id="tensorflow-ic-imagenet-inception-v3-classification-4")
344344

345345
cache = JumpStartModelsCache(s3_bucket_name="some_bucket")
346346
stubbed_s3_client = Stubber(cache._s3_client)
347-
stubbed_s3_client.add_client_error("head_object", service_error_code="AccessDenied")
348347
stubbed_s3_client.add_client_error("get_object", service_error_code="AccessDenied")
349348
stubbed_s3_client.activate()
350349
with pytest.raises(botocore.exceptions.ClientError):
351350
cache.get_header(model_id="tensorflow-ic-imagenet-inception-v3-classification-4")
352351

353352
cache = JumpStartModelsCache(s3_bucket_name="some_bucket")
354353
stubbed_s3_client = Stubber(cache._s3_client)
355-
stubbed_s3_client.add_client_error("head_object", service_error_code="EndpointConnectionError")
356354
stubbed_s3_client.add_client_error("get_object", service_error_code="EndpointConnectionError")
357355
stubbed_s3_client.activate()
358356
with pytest.raises(botocore.exceptions.ClientError):
359357
cache.get_header(model_id="tensorflow-ic-imagenet-inception-v3-classification-4")
360358

359+
# Testing head_object:
360+
mock_now = datetime.datetime.fromtimestamp(1636730651.079551)
361+
with patch("datetime.datetime") as mock_datetime:
362+
mock_manifest_json = json.dumps(
363+
[
364+
{
365+
"model_id": "pytorch-ic-imagenet-inception-v3-classification-4",
366+
"version": "2.0.0",
367+
"min_version": "2.49.0",
368+
"spec_key": "community_models_specs/pytorch-ic-"
369+
"imagenet-inception-v3-classification-4/specs_v2.0.0.json",
370+
}
371+
]
372+
)
373+
374+
get_object_mocked_response = {
375+
"Body": botocore.response.StreamingBody(
376+
io.BytesIO(bytes(mock_manifest_json, "utf-8")),
377+
content_length=len(mock_manifest_json),
378+
),
379+
"ETag": "etag",
380+
}
381+
382+
mock_datetime.now.return_value = mock_now
383+
384+
cache1 = JumpStartModelsCache(
385+
s3_bucket_name="some_bucket", s3_cache_expiration_horizon=datetime.timedelta(hours=1)
386+
)
387+
stubbed_s3_client1 = Stubber(cache1._s3_client)
388+
389+
stubbed_s3_client1.add_response("get_object", copy.deepcopy(get_object_mocked_response))
390+
stubbed_s3_client1.activate()
391+
cache1.get_header(model_id="pytorch-ic-imagenet-inception-v3-classification-4")
392+
393+
mock_datetime.now.return_value += datetime.timedelta(weeks=1)
394+
395+
stubbed_s3_client1.add_client_error("head_object", http_status_code=404)
396+
with pytest.raises(botocore.exceptions.ClientError):
397+
cache1.get_header(model_id="pytorch-ic-imagenet-inception-v3-classification-4")
398+
399+
cache2 = JumpStartModelsCache(
400+
s3_bucket_name="some_bucket", s3_cache_expiration_horizon=datetime.timedelta(hours=1)
401+
)
402+
stubbed_s3_client2 = Stubber(cache2._s3_client)
403+
404+
stubbed_s3_client2.add_response("get_object", copy.deepcopy(get_object_mocked_response))
405+
stubbed_s3_client2.activate()
406+
cache2.get_header(model_id="pytorch-ic-imagenet-inception-v3-classification-4")
407+
408+
mock_datetime.now.return_value += datetime.timedelta(weeks=1)
409+
410+
stubbed_s3_client2.add_client_error("head_object", service_error_code="AccessDenied")
411+
with pytest.raises(botocore.exceptions.ClientError):
412+
cache2.get_header(model_id="pytorch-ic-imagenet-inception-v3-classification-4")
413+
414+
cache3 = JumpStartModelsCache(
415+
s3_bucket_name="some_bucket", s3_cache_expiration_horizon=datetime.timedelta(hours=1)
416+
)
417+
stubbed_s3_client3 = Stubber(cache3._s3_client)
418+
419+
stubbed_s3_client3.add_response("get_object", copy.deepcopy(get_object_mocked_response))
420+
stubbed_s3_client3.activate()
421+
cache3.get_header(model_id="pytorch-ic-imagenet-inception-v3-classification-4")
422+
423+
mock_datetime.now.return_value += datetime.timedelta(weeks=1)
424+
425+
stubbed_s3_client3.add_client_error(
426+
"head_object", service_error_code="EndpointConnectionError"
427+
)
428+
with pytest.raises(botocore.exceptions.ClientError):
429+
cache3.get_header(model_id="pytorch-ic-imagenet-inception-v3-classification-4")
430+
361431

362432
def test_jumpstart_cache_accepts_input_parameters():
363433

@@ -422,19 +492,18 @@ def test_jumpstart_cache_evaluates_md5_hash(mock_boto3_client):
422492
mock_boto3_client.return_value.get_object.return_value = {
423493
"Body": botocore.response.StreamingBody(
424494
io.BytesIO(bytes(mock_json, "utf-8")), content_length=len(mock_json)
425-
)
495+
),
496+
"ETag": "hash1",
426497
}
427498
mock_boto3_client.return_value.head_object.return_value = {"ETag": "hash1"}
428499

429500
cache.get_header(model_id="pytorch-ic-imagenet-inception-v3-classification-4")
430501

431-
# first time accessing cache should involve get_object and head_object
502+
# first time accessing cache should just involve get_object
432503
mock_boto3_client.return_value.get_object.assert_called_with(
433504
Bucket=bucket_name, Key=JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY
434505
)
435-
mock_boto3_client.return_value.head_object.assert_called_with(
436-
Bucket=bucket_name, Key=JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY
437-
)
506+
mock_boto3_client.return_value.head_object.assert_not_called()
438507

439508
mock_boto3_client.return_value.get_object.reset_mock()
440509
mock_boto3_client.return_value.head_object.reset_mock()
@@ -443,7 +512,8 @@ def test_jumpstart_cache_evaluates_md5_hash(mock_boto3_client):
443512
mock_boto3_client.return_value.get_object.return_value = {
444513
"Body": botocore.response.StreamingBody(
445514
io.BytesIO(bytes(mock_json, "utf-8")), content_length=len(mock_json)
446-
)
515+
),
516+
"ETag": "hash1",
447517
}
448518
mock_boto3_client.return_value.head_object.return_value = {"ETag": "hash1"}
449519

@@ -465,7 +535,8 @@ def test_jumpstart_cache_evaluates_md5_hash(mock_boto3_client):
465535
mock_boto3_client.return_value.get_object.return_value = {
466536
"Body": botocore.response.StreamingBody(
467537
io.BytesIO(bytes(mock_json, "utf-8")), content_length=len(mock_json)
468-
)
538+
),
539+
"ETag": "hash2",
469540
}
470541

471542
# invalidate cache
@@ -499,7 +570,8 @@ def test_jumpstart_cache_makes_correct_s3_calls(mock_boto3_client):
499570
mock_boto3_client.return_value.get_object.return_value = {
500571
"Body": botocore.response.StreamingBody(
501572
io.BytesIO(bytes(mock_json, "utf-8")), content_length=len(mock_json)
502-
)
573+
),
574+
"ETag": "etag",
503575
}
504576

505577
mock_boto3_client.return_value.head_object.return_value = {"ETag": "some-hash"}
@@ -514,9 +586,8 @@ def test_jumpstart_cache_makes_correct_s3_calls(mock_boto3_client):
514586
mock_boto3_client.return_value.get_object.assert_called_with(
515587
Bucket=bucket_name, Key=JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY
516588
)
517-
mock_boto3_client.return_value.head_object.assert_called_with(
518-
Bucket=bucket_name, Key=JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY
519-
)
589+
mock_boto3_client.return_value.head_object.assert_not_called()
590+
520591
mock_boto3_client.assert_called_with("s3", region_name="my_region", config=client_config)
521592

522593
# test get_specs. manifest already in cache, so only s3 call will be to get specs.
@@ -527,7 +598,8 @@ def test_jumpstart_cache_makes_correct_s3_calls(mock_boto3_client):
527598
mock_boto3_client.return_value.get_object.return_value = {
528599
"Body": botocore.response.StreamingBody(
529600
io.BytesIO(bytes(mock_json, "utf-8")), content_length=len(mock_json)
530-
)
601+
),
602+
"ETag": "etag",
531603
}
532604
cache.get_specs(model_id="pytorch-ic-imagenet-inception-v3-classification-4")
533605

tests/unit/sagemaker/jumpstart/test_types.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -69,10 +69,10 @@ def test_jumpstart_model_specs():
6969
"framework_version": "1.9.0",
7070
"py_version": "py3",
7171
},
72-
"hosting_artifact_uri": "pytorch-infer/infer-pytorch-ic-mobilenet-v2.tar.gz",
73-
"training_artifact_uri": "pytorch-training/train-pytorch-ic-mobilenet-v2.tar.gz",
74-
"hosting_script_uri": "source-directory-tarballs/pytorch/inference/ic/v1.0.0/sourcedir.tar.gz",
75-
"training_script_uri": "source-directory-tarballs/pytorch/transfer_learning/ic/v1.0.0/sourcedir.tar.gz",
72+
"hosting_artifact_key": "pytorch-infer/infer-pytorch-ic-mobilenet-v2.tar.gz",
73+
"training_artifact_key": "pytorch-training/train-pytorch-ic-mobilenet-v2.tar.gz",
74+
"hosting_script_key": "source-directory-tarballs/pytorch/inference/ic/v1.0.0/sourcedir.tar.gz",
75+
"training_script_key": "source-directory-tarballs/pytorch/transfer_learning/ic/v1.0.0/sourcedir.tar.gz",
7676
"hyperparameters": {
7777
"adam-learning-rate": {"type": "float", "default": 0.05, "min": 1e-08, "max": 1},
7878
"epochs": {"type": "int", "default": 3, "min": 1, "max": 1000},
@@ -101,14 +101,14 @@ def test_jumpstart_model_specs():
101101
"py_version": "py3",
102102
}
103103
)
104-
assert specs1.hosting_artifact_uri == "pytorch-infer/infer-pytorch-ic-mobilenet-v2.tar.gz"
105-
assert specs1.training_artifact_uri == "pytorch-training/train-pytorch-ic-mobilenet-v2.tar.gz"
104+
assert specs1.hosting_artifact_key == "pytorch-infer/infer-pytorch-ic-mobilenet-v2.tar.gz"
105+
assert specs1.training_artifact_key == "pytorch-training/train-pytorch-ic-mobilenet-v2.tar.gz"
106106
assert (
107-
specs1.hosting_script_uri
107+
specs1.hosting_script_key
108108
== "source-directory-tarballs/pytorch/inference/ic/v1.0.0/sourcedir.tar.gz"
109109
)
110110
assert (
111-
specs1.training_script_uri
111+
specs1.training_script_key
112112
== "source-directory-tarballs/pytorch/transfer_learning/ic/v1.0.0/sourcedir.tar.gz"
113113
)
114114
assert specs1.hyperparameters == {

tests/unit/sagemaker/utilities/test_cache.py

Lines changed: 6 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
from typing import Optional, Union
1515
from mock.mock import MagicMock, patch
1616
import pytest
17-
import pickle
1817

1918

2019
from sagemaker.utilities import cache
@@ -54,16 +53,16 @@ def test_cache_invalidates_old_item():
5453
retrieval_function=retrieval_function,
5554
)
5655

57-
curr_time = datetime.datetime.fromtimestamp(1636730651.079551)
56+
mock_curr_time = datetime.datetime.fromtimestamp(1636730651.079551)
5857
with patch("datetime.datetime") as mock_datetime:
59-
mock_datetime.now.return_value = curr_time
58+
mock_datetime.now.return_value = mock_curr_time
6059
my_cache.put(5)
6160
mock_datetime.now.return_value += datetime.timedelta(milliseconds=2)
6261
with pytest.raises(KeyError):
6362
my_cache.get(5, False)
6463

6564
with patch("datetime.datetime") as mock_datetime:
66-
mock_datetime.now.return_value = curr_time
65+
mock_datetime.now.return_value = mock_curr_time
6766
my_cache.put(5)
6867
mock_datetime.now.return_value += datetime.timedelta(milliseconds=0.5)
6968
assert my_cache.get(5, False) == retrieval_function(key=5)
@@ -76,15 +75,15 @@ def test_cache_fetches_new_item():
7675
retrieval_function=retrieval_function,
7776
)
7877

79-
curr_time = datetime.datetime.fromtimestamp(1636730651.079551)
78+
mock_curr_time = datetime.datetime.fromtimestamp(1636730651.079551)
8079
with patch("datetime.datetime") as mock_datetime:
81-
mock_datetime.now.return_value = curr_time
80+
mock_datetime.now.return_value = mock_curr_time
8281
my_cache.put(5, 10)
8382
mock_datetime.now.return_value += datetime.timedelta(milliseconds=2)
8483
assert my_cache.get(5) == retrieval_function(key=5)
8584

8685
with patch("datetime.datetime") as mock_datetime:
87-
mock_datetime.now.return_value = curr_time
86+
mock_datetime.now.return_value = mock_curr_time
8887
my_cache.put(5, 10)
8988
mock_datetime.now.return_value += datetime.timedelta(milliseconds=0.5)
9089
assert my_cache.get(5, False) == 10
@@ -194,24 +193,3 @@ def test_cache_clear_and_contains():
194193
assert len(my_cache) == 0
195194
with pytest.raises(KeyError):
196195
my_cache.get(1, False)
197-
198-
199-
def test_cache_memory_usage():
200-
my_cache = cache.LRUCache[int, Union[int, str]](
201-
max_cache_items=10,
202-
expiration_horizon=datetime.timedelta(hours=1),
203-
retrieval_function=retrieval_function,
204-
)
205-
cache_size_bytes = []
206-
cache_size_bytes.append(len(pickle.dumps(my_cache)))
207-
for i in range(50):
208-
my_cache.put(i)
209-
cache_size_bytes.append(len(pickle.dumps(my_cache)))
210-
211-
max_cache_items_iter_cache_size = cache_size_bytes[10]
212-
past_capacity_iter_cache_size = cache_size_bytes[50]
213-
percent_difference_cache_size = (
214-
abs(past_capacity_iter_cache_size - max_cache_items_iter_cache_size)
215-
/ max_cache_items_iter_cache_size
216-
) * 100.0
217-
assert percent_difference_cache_size < 1

0 commit comments

Comments (0)