Skip to content

Commit 6604acc

Browse files
committed
change: require version for cache get operations
1 parent 02b1a97 commit 6604acc

File tree

2 files changed: 76 additions and 57 deletions.

src/sagemaker/jumpstart/cache.py

Lines changed: 14 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -262,31 +262,32 @@ def get_manifest(self) -> List[JumpStartModelHeader]:
262262
JumpStartCachedS3ContentKey(JumpStartS3FileType.MANIFEST, self._manifest_file_s3_key)
263263
).formatted_file_content.values()
264264

265-
def get_header(
266-
self, model_id: str, semantic_version_str: Optional[str] = None
267-
) -> JumpStartModelHeader:
265+
def get_header(self, model_id: str, semantic_version_str: str) -> JumpStartModelHeader:
268266
"""Return header for a given JumpStart model id and semantic version.
269267
270268
Args:
271269
model_id (str): model id for which to get a header.
272-
semantic_version_str (Optional[str]): The semantic version for which to get a
273-
header. If None, the highest compatible version is returned.
270+
semantic_version_str (str): The semantic version for which to get a
271+
header.
274272
"""
275273

276274
return self._get_header_impl(model_id, semantic_version_str=semantic_version_str)
277275

278276
def _get_header_impl(
279-
self, model_id: str, attempt: Optional[int] = 0, semantic_version_str: Optional[str] = None
277+
self,
278+
model_id: str,
279+
semantic_version_str: str,
280+
attempt: Optional[int] = 0,
280281
) -> JumpStartModelHeader:
281282
"""Lower-level function to return header.
282283
283284
Allows a single retry if the cache is old.
284285
285286
Args:
286287
model_id (str): model id for which to get a header.
287-
attempt (int): attempt number at retrieving a header.
288-
semantic_version_str (Optional[str]): The semantic version for which to get a
289-
header. If None, the highest compatible version is returned.
288+
semantic_version_str (str): The semantic version for which to get a
289+
header.
290+
attempt (Optional[int]): attempt number at retrieving a header.
290291
"""
291292

292293
versioned_model_id = self._model_id_semantic_version_manifest_key_cache.get(
@@ -301,17 +302,15 @@ def _get_header_impl(
301302
if attempt > 0:
302303
raise
303304
self.clear()
304-
return self._get_header_impl(model_id, attempt + 1, semantic_version_str)
305+
return self._get_header_impl(model_id, semantic_version_str, attempt + 1)
305306

306-
def get_specs(
307-
self, model_id: str, semantic_version_str: Optional[str] = None
308-
) -> JumpStartModelSpecs:
307+
def get_specs(self, model_id: str, semantic_version_str: str) -> JumpStartModelSpecs:
309308
"""Return specs for a given JumpStart model id and semantic version.
310309
311310
Args:
312311
model_id (str): model id for which to get specs.
313-
semantic_version_str (Optional[str]): The semantic version for which to get
314-
specs. If None, the highest compatible version is returned.
312+
semantic_version_str (str): The semantic version for which to get
313+
specs.
315314
"""
316315

317316
header = self.get_header(model_id, semantic_version_str)

tests/unit/sagemaker/jumpstart/test_cache.py

Lines changed: 62 additions & 42 deletions
Original file line number | Diff line number | Diff line change
@@ -136,33 +136,19 @@ def test_jumpstart_cache_get_header():
136136

137137
cache = JumpStartModelsCache(s3_bucket_name="some_bucket")
138138

139-
assert (
140-
JumpStartModelHeader(
141-
{
142-
"model_id": "tensorflow-ic-imagenet-inception-v3-classification-4",
143-
"version": "2.0.0",
144-
"min_version": "2.49.0",
145-
"spec_key": "community_models_specs/tensorflow-ic"
146-
"-imagenet-inception-v3-classification-4/specs_v2.0.0.json",
147-
}
148-
)
149-
== cache.get_header(model_id="tensorflow-ic-imagenet-inception-v3-classification-4")
139+
assert JumpStartModelHeader(
140+
{
141+
"model_id": "tensorflow-ic-imagenet-inception-v3-classification-4",
142+
"version": "2.0.0",
143+
"min_version": "2.49.0",
144+
"spec_key": "community_models_specs/tensorflow-ic"
145+
"-imagenet-inception-v3-classification-4/specs_v2.0.0.json",
146+
}
147+
) == cache.get_header(
148+
model_id="tensorflow-ic-imagenet-inception-v3-classification-4", semantic_version_str="*"
150149
)
151150

152151
# See if we can make the same query 2 times consecutively
153-
assert (
154-
JumpStartModelHeader(
155-
{
156-
"model_id": "tensorflow-ic-imagenet-inception-v3-classification-4",
157-
"version": "2.0.0",
158-
"min_version": "2.49.0",
159-
"spec_key": "community_models_specs/tensorflow-ic"
160-
"-imagenet-inception-v3-classification-4/specs_v2.0.0.json",
161-
}
162-
)
163-
== cache.get_header(model_id="tensorflow-ic-imagenet-inception-v3-classification-4")
164-
)
165-
166152
assert JumpStartModelHeader(
167153
{
168154
"model_id": "tensorflow-ic-imagenet-inception-v3-classification-4",
@@ -278,6 +264,7 @@ def test_jumpstart_cache_get_header():
278264
with pytest.raises(KeyError):
279265
cache.get_header(
280266
model_id="tensorflow-ic-imagenet-inception-v3-classification-4-bak",
267+
semantic_version_str="*",
281268
)
282269

283270

@@ -340,21 +327,30 @@ def test_jumpstart_cache_handles_boto3_client_errors():
340327
stubbed_s3_client.add_client_error("get_object", http_status_code=404)
341328
stubbed_s3_client.activate()
342329
with pytest.raises(botocore.exceptions.ClientError):
343-
cache.get_header(model_id="tensorflow-ic-imagenet-inception-v3-classification-4")
330+
cache.get_header(
331+
model_id="tensorflow-ic-imagenet-inception-v3-classification-4",
332+
semantic_version_str="*",
333+
)
344334

345335
cache = JumpStartModelsCache(s3_bucket_name="some_bucket")
346336
stubbed_s3_client = Stubber(cache._s3_client)
347337
stubbed_s3_client.add_client_error("get_object", service_error_code="AccessDenied")
348338
stubbed_s3_client.activate()
349339
with pytest.raises(botocore.exceptions.ClientError):
350-
cache.get_header(model_id="tensorflow-ic-imagenet-inception-v3-classification-4")
340+
cache.get_header(
341+
model_id="tensorflow-ic-imagenet-inception-v3-classification-4",
342+
semantic_version_str="*",
343+
)
351344

352345
cache = JumpStartModelsCache(s3_bucket_name="some_bucket")
353346
stubbed_s3_client = Stubber(cache._s3_client)
354347
stubbed_s3_client.add_client_error("get_object", service_error_code="EndpointConnectionError")
355348
stubbed_s3_client.activate()
356349
with pytest.raises(botocore.exceptions.ClientError):
357-
cache.get_header(model_id="tensorflow-ic-imagenet-inception-v3-classification-4")
350+
cache.get_header(
351+
model_id="tensorflow-ic-imagenet-inception-v3-classification-4",
352+
semantic_version_str="*",
353+
)
358354

359355
# Testing head_object:
360356
mock_now = datetime.datetime.fromtimestamp(1636730651.079551)
@@ -388,13 +384,18 @@ def test_jumpstart_cache_handles_boto3_client_errors():
388384

389385
stubbed_s3_client1.add_response("get_object", copy.deepcopy(get_object_mocked_response))
390386
stubbed_s3_client1.activate()
391-
cache1.get_header(model_id="pytorch-ic-imagenet-inception-v3-classification-4")
387+
cache1.get_header(
388+
model_id="pytorch-ic-imagenet-inception-v3-classification-4", semantic_version_str="*"
389+
)
392390

393391
mock_datetime.now.return_value += datetime.timedelta(weeks=1)
394392

395393
stubbed_s3_client1.add_client_error("head_object", http_status_code=404)
396394
with pytest.raises(botocore.exceptions.ClientError):
397-
cache1.get_header(model_id="pytorch-ic-imagenet-inception-v3-classification-4")
395+
cache1.get_header(
396+
model_id="pytorch-ic-imagenet-inception-v3-classification-4",
397+
semantic_version_str="*",
398+
)
398399

399400
cache2 = JumpStartModelsCache(
400401
s3_bucket_name="some_bucket", s3_cache_expiration_horizon=datetime.timedelta(hours=1)
@@ -403,13 +404,18 @@ def test_jumpstart_cache_handles_boto3_client_errors():
403404

404405
stubbed_s3_client2.add_response("get_object", copy.deepcopy(get_object_mocked_response))
405406
stubbed_s3_client2.activate()
406-
cache2.get_header(model_id="pytorch-ic-imagenet-inception-v3-classification-4")
407+
cache2.get_header(
408+
model_id="pytorch-ic-imagenet-inception-v3-classification-4", semantic_version_str="*"
409+
)
407410

408411
mock_datetime.now.return_value += datetime.timedelta(weeks=1)
409412

410413
stubbed_s3_client2.add_client_error("head_object", service_error_code="AccessDenied")
411414
with pytest.raises(botocore.exceptions.ClientError):
412-
cache2.get_header(model_id="pytorch-ic-imagenet-inception-v3-classification-4")
415+
cache2.get_header(
416+
model_id="pytorch-ic-imagenet-inception-v3-classification-4",
417+
semantic_version_str="*",
418+
)
413419

414420
cache3 = JumpStartModelsCache(
415421
s3_bucket_name="some_bucket", s3_cache_expiration_horizon=datetime.timedelta(hours=1)
@@ -418,15 +424,20 @@ def test_jumpstart_cache_handles_boto3_client_errors():
418424

419425
stubbed_s3_client3.add_response("get_object", copy.deepcopy(get_object_mocked_response))
420426
stubbed_s3_client3.activate()
421-
cache3.get_header(model_id="pytorch-ic-imagenet-inception-v3-classification-4")
427+
cache3.get_header(
428+
model_id="pytorch-ic-imagenet-inception-v3-classification-4", semantic_version_str="*"
429+
)
422430

423431
mock_datetime.now.return_value += datetime.timedelta(weeks=1)
424432

425433
stubbed_s3_client3.add_client_error(
426434
"head_object", service_error_code="EndpointConnectionError"
427435
)
428436
with pytest.raises(botocore.exceptions.ClientError):
429-
cache3.get_header(model_id="pytorch-ic-imagenet-inception-v3-classification-4")
437+
cache3.get_header(
438+
model_id="pytorch-ic-imagenet-inception-v3-classification-4",
439+
semantic_version_str="*",
440+
)
430441

431442

432443
def test_jumpstart_cache_accepts_input_parameters():
@@ -497,7 +508,9 @@ def test_jumpstart_cache_evaluates_md5_hash(mock_boto3_client):
497508
}
498509
mock_boto3_client.return_value.head_object.return_value = {"ETag": "hash1"}
499510

500-
cache.get_header(model_id="pytorch-ic-imagenet-inception-v3-classification-4")
511+
cache.get_header(
512+
model_id="pytorch-ic-imagenet-inception-v3-classification-4", semantic_version_str="*"
513+
)
501514

502515
# first time accessing cache should just involve get_object
503516
mock_boto3_client.return_value.get_object.assert_called_with(
@@ -520,7 +533,9 @@ def test_jumpstart_cache_evaluates_md5_hash(mock_boto3_client):
520533
# invalidate cache
521534
mock_datetime.now.return_value += datetime.timedelta(hours=2)
522535

523-
cache.get_header(model_id="pytorch-ic-imagenet-inception-v3-classification-4")
536+
cache.get_header(
537+
model_id="pytorch-ic-imagenet-inception-v3-classification-4", semantic_version_str="*"
538+
)
524539

525540
mock_boto3_client.return_value.head_object.assert_called_with(
526541
Bucket=bucket_name, Key=JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY
@@ -542,7 +557,9 @@ def test_jumpstart_cache_evaluates_md5_hash(mock_boto3_client):
542557
# invalidate cache
543558
mock_datetime.now.return_value += datetime.timedelta(hours=2)
544559

545-
cache.get_header(model_id="pytorch-ic-imagenet-inception-v3-classification-4")
560+
cache.get_header(
561+
model_id="pytorch-ic-imagenet-inception-v3-classification-4", semantic_version_str="*"
562+
)
546563

547564
mock_boto3_client.return_value.get_object.assert_called_with(
548565
Bucket=bucket_name, Key=JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY
@@ -581,7 +598,9 @@ def test_jumpstart_cache_makes_correct_s3_calls(mock_boto3_client):
581598
cache = JumpStartModelsCache(
582599
s3_bucket_name=bucket_name, s3_client_config=client_config, region="my_region"
583600
)
584-
cache.get_header(model_id="pytorch-ic-imagenet-inception-v3-classification-4")
601+
cache.get_header(
602+
model_id="pytorch-ic-imagenet-inception-v3-classification-4", semantic_version_str="*"
603+
)
585604

586605
mock_boto3_client.return_value.get_object.assert_called_with(
587606
Bucket=bucket_name, Key=JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY
@@ -601,7 +620,9 @@ def test_jumpstart_cache_makes_correct_s3_calls(mock_boto3_client):
601620
),
602621
"ETag": "etag",
603622
}
604-
cache.get_specs(model_id="pytorch-ic-imagenet-inception-v3-classification-4")
623+
cache.get_specs(
624+
model_id="pytorch-ic-imagenet-inception-v3-classification-4", semantic_version_str="*"
625+
)
605626

606627
mock_boto3_client.return_value.get_object.assert_called_with(
607628
Bucket=bucket_name,
@@ -633,7 +654,7 @@ def test_jumpstart_cache_handles_bad_semantic_version_manifest_key_cache():
633654
"imagenet-inception-v3-classification-4/specs_v1.0.0.json",
634655
}
635656
) == cache.get_header(
636-
model_id="tensorflow-ic-imagenet-inception-v3-classification-4",
657+
model_id="tensorflow-ic-imagenet-inception-v3-classification-4", semantic_version_str="*"
637658
)
638659
cache.clear.assert_called_once()
639660
cache.clear.reset_mock()
@@ -649,6 +670,7 @@ def test_jumpstart_cache_handles_bad_semantic_version_manifest_key_cache():
649670
with pytest.raises(KeyError):
650671
cache.get_header(
651672
model_id="tensorflow-ic-imagenet-inception-v3-classification-4",
673+
semantic_version_str="*",
652674
)
653675
cache.clear.assert_called_once()
654676

@@ -683,9 +705,7 @@ def test_jumpstart_cache_get_specs():
683705
)
684706

685707
with pytest.raises(KeyError):
686-
cache.get_specs(
687-
model_id=model_id + "bak",
688-
)
708+
cache.get_specs(model_id=model_id + "bak", semantic_version_str="*")
689709

690710
with pytest.raises(KeyError):
691711
cache.get_specs(model_id=model_id, semantic_version_str="9.*")

0 commit comments

Comments
 (0)