Skip to content

Commit f6a1692

Browse files
evakraviknikure
authored andcommitted
feat: Enhance jumpstart logging for wildcard version identifiers and old model versions #1273 (#1301)
1 parent 69fb754 commit f6a1692

File tree

8 files changed

+264
-52
lines changed

8 files changed

+264
-52
lines changed

src/sagemaker/jumpstart/cache.py

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,10 @@
2626
ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE,
2727
JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY,
2828
JUMPSTART_DEFAULT_REGION_NAME,
29+
JUMPSTART_LOGGER,
2930
MODEL_ID_LIST_WEB_URL,
3031
)
32+
from sagemaker.jumpstart.exceptions import get_wildcard_model_version_msg
3133
from sagemaker.jumpstart.parameters import (
3234
JUMPSTART_DEFAULT_MAX_S3_CACHE_ITEMS,
3335
JUMPSTART_DEFAULT_MAX_SEMANTIC_VERSION_CACHE_ITEMS,
@@ -172,7 +174,7 @@ def _get_manifest_key_from_model_id_semantic_version(
172174

173175
manifest = self._s3_cache.get(
174176
JumpStartCachedS3ContentKey(JumpStartS3FileType.MANIFEST, self._manifest_file_s3_key)
175-
).formatted_content
177+
)[0].formatted_content
176178

177179
sm_version = utils.get_sagemaker_version()
178180

@@ -330,7 +332,7 @@ def _retrieval_function(
330332
if file_type == JumpStartS3FileType.SPECS:
331333
formatted_body, _ = self._get_json_file(s3_key, file_type)
332334
model_specs = JumpStartModelSpecs(formatted_body)
333-
utils.emit_logs_based_on_model_specs(model_specs, self.get_region())
335+
utils.emit_logs_based_on_model_specs(model_specs, self.get_region(), self._s3_client)
334336
return JumpStartCachedS3ContentValue(
335337
formatted_content=model_specs
336338
)
@@ -343,7 +345,7 @@ def get_manifest(self) -> List[JumpStartModelHeader]:
343345

344346
manifest_dict = self._s3_cache.get(
345347
JumpStartCachedS3ContentKey(JumpStartS3FileType.MANIFEST, self._manifest_file_s3_key)
346-
).formatted_content
348+
)[0].formatted_content
347349
manifest = list(manifest_dict.values()) # type: ignore
348350
return manifest
349351

@@ -403,10 +405,11 @@ def _get_header_impl(
403405

404406
versioned_model_id = self._model_id_semantic_version_manifest_key_cache.get(
405407
JumpStartVersionedModelId(model_id, semantic_version_str)
406-
)
408+
)[0]
409+
407410
manifest = self._s3_cache.get(
408411
JumpStartCachedS3ContentKey(JumpStartS3FileType.MANIFEST, self._manifest_file_s3_key)
409-
).formatted_content
412+
)[0].formatted_content
410413
try:
411414
header = manifest[versioned_model_id] # type: ignore
412415
return header
@@ -427,10 +430,18 @@ def get_specs(self, model_id: str, semantic_version_str: str) -> JumpStartModelS
427430

428431
header = self.get_header(model_id, semantic_version_str)
429432
spec_key = header.spec_key
430-
specs = self._s3_cache.get(
433+
specs, cache_hit = self._s3_cache.get(
431434
JumpStartCachedS3ContentKey(JumpStartS3FileType.SPECS, spec_key)
432-
).formatted_content
433-
return specs # type: ignore
435+
)
436+
if not cache_hit and "*" in semantic_version_str:
437+
JUMPSTART_LOGGER.warning(
438+
get_wildcard_model_version_msg(
439+
header.model_id,
440+
semantic_version_str,
441+
header.version
442+
)
443+
)
444+
return specs.formatted_content
434445

435446
def clear(self) -> None:
436447
"""Clears the model ID/version and s3 cache."""

src/sagemaker/jumpstart/exceptions.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,35 @@
3030
)
3131

3232

33+
_MAJOR_VERSION_WARNING_MSG = (
34+
"Note that models may have different input/output signatures after a major version upgrade."
35+
)
36+
37+
38+
def get_wildcard_model_version_msg(
39+
model_id: str, wildcard_model_version: str, full_model_version: str
40+
) -> str:
41+
"""Returns customer-facing message for using a model version with a wildcard character."""
42+
43+
return (
44+
f"Using model '{model_id}' with wildcard version identifier '{wildcard_model_version}'. "
45+
f"Please consider pinning to version '{full_model_version}' to "
46+
f"ensure stable results. {_MAJOR_VERSION_WARNING_MSG}"
47+
)
48+
49+
50+
def get_old_model_version_msg(
51+
model_id: str, current_model_version: str, latest_model_version: str
52+
) -> str:
53+
"""Returns customer-facing message associated with using an old model version."""
54+
55+
return (
56+
f"Using model '{model_id}' with old version '{current_model_version}'. "
57+
f"Please consider upgrading to version '{latest_model_version}'"
58+
f". {_MAJOR_VERSION_WARNING_MSG}"
59+
)
60+
61+
3362
class JumpStartHyperparametersError(ValueError):
3463
"""Exception raised for bad hyperparameters of a JumpStart model."""
3564

src/sagemaker/jumpstart/utils.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import os
1717
from typing import Any, Dict, List, Optional, Union
1818
from urllib.parse import urlparse
19+
import boto3
1920
from packaging.version import Version
2021
import sagemaker
2122
from sagemaker.config.config_schema import (
@@ -31,6 +32,7 @@
3132
from sagemaker.jumpstart.exceptions import (
3233
DeprecatedJumpStartModelError,
3334
VulnerableJumpStartModelError,
35+
get_old_model_version_msg,
3436
)
3537
from sagemaker.jumpstart.types import (
3638
JumpStartModelHeader,
@@ -462,7 +464,9 @@ def update_inference_tags_with_jumpstart_training_tags(
462464
return inference_tags
463465

464466

465-
def emit_logs_based_on_model_specs(model_specs: JumpStartModelSpecs, region: str) -> None:
467+
def emit_logs_based_on_model_specs(
468+
model_specs: JumpStartModelSpecs, region: str, s3_client: boto3.client
469+
) -> None:
466470
"""Emits logs based on model specs and region."""
467471

468472
if model_specs.hosting_eula_key:
@@ -476,6 +480,24 @@ def emit_logs_based_on_model_specs(model_specs: JumpStartModelSpecs, region: str
476480
model_specs.hosting_eula_key,
477481
)
478482

483+
full_version: str = model_specs.version
484+
485+
models_manifest_list = accessors.JumpStartModelsAccessor._get_manifest(
486+
region=region, s3_client=s3_client
487+
)
488+
max_version_for_model_id: Optional[str] = None
489+
for header in models_manifest_list:
490+
if header.model_id == model_specs.model_id:
491+
if max_version_for_model_id is None or Version(header.version) > Version(
492+
max_version_for_model_id
493+
):
494+
max_version_for_model_id = header.version
495+
496+
if full_version != max_version_for_model_id:
497+
constants.JUMPSTART_LOGGER.info(
498+
get_old_model_version_msg(model_specs.model_id, full_version, max_version_for_model_id)
499+
)
500+
479501
if model_specs.deprecated:
480502
deprecated_message = model_specs.deprecated_message or (
481503
"Using deprecated JumpStart model "

src/sagemaker/utilities/cache.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
import datetime
1717
import collections
18-
from typing import TypeVar, Generic, Callable, Optional
18+
from typing import Tuple, TypeVar, Generic, Callable, Optional
1919

2020
KeyType = TypeVar("KeyType")
2121
ValType = TypeVar("ValType")
@@ -86,23 +86,26 @@ def clear(self) -> None:
8686
"""Deletes all elements from the cache."""
8787
self._lru_cache.clear()
8888

89-
def get(self, key: KeyType, data_source_fallback: Optional[bool] = True) -> ValType:
90-
"""Returns value corresponding to key in cache.
89+
def get(
90+
self, key: KeyType, data_source_fallback: Optional[bool] = True
91+
) -> Tuple[ValType, bool]:
92+
"""Returns value corresponding to key in cache and boolean indicating cache hit.
9193
9294
Args:
9395
key (KeyType): Key in cache to retrieve.
9496
data_source_fallback (Optional[bool]): True if data should be retrieved if
9597
it's stale or not in cache. Default: True.
96-
Raises:
97-
KeyError: If key is not found in cache or is outdated and
98-
``data_source_fallback`` is False.
98+
99+
Raises:
100+
KeyError: If key is not found in cache or is outdated and
101+
``data_source_fallback`` is False.
99102
"""
100103
if data_source_fallback:
101104
if key in self._lru_cache:
102-
return self._get_item(key, False)
105+
return self._get_item(key, False), True
103106
self.put(key)
104-
return self._get_item(key, False)
105-
return self._get_item(key, True)
107+
return self._get_item(key, False), False
108+
return self._get_item(key, True), True
106109

107110
def put(self, key: KeyType, value: Optional[ValType] = None) -> None:
108111
"""Adds key to cache using ``retrieval_function``.

tests/unit/sagemaker/jumpstart/test_cache.py

Lines changed: 54 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
BASE_MANIFEST,
4343
BASE_SPEC,
4444
)
45+
from sagemaker.jumpstart.utils import get_jumpstart_content_bucket
4546

4647

4748
@patch.object(JumpStartModelsCache, "_retrieval_function", patched_retrieval_function)
@@ -525,11 +526,14 @@ def test_jumpstart_cache_evaluates_md5_hash(mock_boto3_client):
525526
)
526527

527528

529+
@patch("sagemaker.jumpstart.cache.utils.emit_logs_based_on_model_specs")
528530
@patch("boto3.client")
529-
def test_jumpstart_cache_makes_correct_s3_calls(mock_boto3_client):
531+
def test_jumpstart_cache_makes_correct_s3_calls(
532+
mock_boto3_client, mock_emit_logs_based_on_model_specs
533+
):
530534

531535
# test get_header
532-
mock_json = json.dumps(
536+
mock_manifest_json = json.dumps(
533537
[
534538
{
535539
"model_id": "pytorch-ic-imagenet-inception-v3-classification-4",
@@ -542,17 +546,17 @@ def test_jumpstart_cache_makes_correct_s3_calls(mock_boto3_client):
542546
)
543547
mock_boto3_client.return_value.get_object.return_value = {
544548
"Body": botocore.response.StreamingBody(
545-
io.BytesIO(bytes(mock_json, "utf-8")), content_length=len(mock_json)
549+
io.BytesIO(bytes(mock_manifest_json, "utf-8")), content_length=len(mock_manifest_json)
546550
),
547551
"ETag": "etag",
548552
}
549553

550554
mock_boto3_client.return_value.head_object.return_value = {"ETag": "some-hash"}
551555

552-
bucket_name = "bucket_name"
556+
bucket_name = get_jumpstart_content_bucket("us-west-2")
553557
client_config = botocore.config.Config(signature_version="my_signature_version")
554558
cache = JumpStartModelsCache(
555-
s3_bucket_name=bucket_name, s3_client_config=client_config, region="my_region"
559+
s3_bucket_name=bucket_name, s3_client_config=client_config, region="us-west-2"
556560
)
557561
cache.get_header(
558562
model_id="pytorch-ic-imagenet-inception-v3-classification-4", semantic_version_str="*"
@@ -563,7 +567,7 @@ def test_jumpstart_cache_makes_correct_s3_calls(mock_boto3_client):
563567
)
564568
mock_boto3_client.return_value.head_object.assert_not_called()
565569

566-
mock_boto3_client.assert_called_with("s3", region_name="my_region", config=client_config)
570+
mock_boto3_client.assert_called_with("s3", region_name="us-west-2", config=client_config)
567571

568572
# test get_specs. manifest already in cache, so only s3 call will be to get specs.
569573
mock_json = json.dumps(BASE_SPEC)
@@ -576,9 +580,22 @@ def test_jumpstart_cache_makes_correct_s3_calls(mock_boto3_client):
576580
),
577581
"ETag": "etag",
578582
}
579-
cache.get_specs(
580-
model_id="pytorch-ic-imagenet-inception-v3-classification-4", semantic_version_str="*"
581-
)
583+
584+
with patch("logging.Logger.warning") as mocked_warning_log:
585+
cache.get_specs(
586+
model_id="pytorch-ic-imagenet-inception-v3-classification-4", semantic_version_str="*"
587+
)
588+
mocked_warning_log.assert_called_once_with(
589+
"Using model 'pytorch-ic-imagenet-inception-v3-classification-4' with wildcard "
590+
"version identifier '*'. Please consider pinning to version '2.0.0' to "
591+
"ensure stable results. Note that models may have different input/output "
592+
"signatures after a major version upgrade."
593+
)
594+
mocked_warning_log.reset_mock()
595+
cache.get_specs(
596+
model_id="pytorch-ic-imagenet-inception-v3-classification-4", semantic_version_str="*"
597+
)
598+
mocked_warning_log.assert_not_called()
582599

583600
mock_boto3_client.return_value.get_object.assert_called_with(
584601
Bucket=bucket_name,
@@ -595,10 +612,18 @@ def test_jumpstart_cache_handles_bad_semantic_version_manifest_key_cache():
595612
cache.clear = MagicMock()
596613
cache._model_id_semantic_version_manifest_key_cache = MagicMock()
597614
cache._model_id_semantic_version_manifest_key_cache.get.side_effect = [
598-
JumpStartVersionedModelId(
599-
"tensorflow-ic-imagenet-inception-v3-classification-4", "999.0.0"
615+
(
616+
JumpStartVersionedModelId(
617+
"tensorflow-ic-imagenet-inception-v3-classification-4", "999.0.0"
618+
),
619+
True,
620+
),
621+
(
622+
JumpStartVersionedModelId(
623+
"tensorflow-ic-imagenet-inception-v3-classification-4", "1.0.0"
624+
),
625+
True,
600626
),
601-
JumpStartVersionedModelId("tensorflow-ic-imagenet-inception-v3-classification-4", "1.0.0"),
602627
]
603628

604629
assert JumpStartModelHeader(
@@ -616,11 +641,17 @@ def test_jumpstart_cache_handles_bad_semantic_version_manifest_key_cache():
616641
cache.clear.reset_mock()
617642

618643
cache._model_id_semantic_version_manifest_key_cache.get.side_effect = [
619-
JumpStartVersionedModelId(
620-
"tensorflow-ic-imagenet-inception-v3-classification-4", "999.0.0"
644+
(
645+
JumpStartVersionedModelId(
646+
"tensorflow-ic-imagenet-inception-v3-classification-4", "999.0.0"
647+
),
648+
True,
621649
),
622-
JumpStartVersionedModelId(
623-
"tensorflow-ic-imagenet-inception-v3-classification-4", "987.0.0"
650+
(
651+
JumpStartVersionedModelId(
652+
"tensorflow-ic-imagenet-inception-v3-classification-4", "987.0.0"
653+
),
654+
True,
624655
),
625656
]
626657
with pytest.raises(KeyError):
@@ -735,6 +766,7 @@ def test_jumpstart_local_metadata_override_header(
735766
mocked_get_json_file_and_etag_from_s3.assert_not_called()
736767

737768

769+
@patch("sagemaker.jumpstart.cache.utils.emit_logs_based_on_model_specs")
738770
@patch.object(JumpStartModelsCache, "_get_json_file_and_etag_from_s3")
739771
@patch("sagemaker.jumpstart.utils.get_sagemaker_version", lambda: "2.68.3")
740772
@patch.dict(
@@ -747,7 +779,10 @@ def test_jumpstart_local_metadata_override_header(
747779
@patch("sagemaker.jumpstart.cache.os.path.isdir")
748780
@patch("builtins.open")
749781
def test_jumpstart_local_metadata_override_specs(
750-
mocked_open: Mock, mocked_is_dir: Mock, mocked_get_json_file_and_etag_from_s3: Mock
782+
mocked_open: Mock,
783+
mocked_is_dir: Mock,
784+
mocked_get_json_file_and_etag_from_s3: Mock,
785+
mock_emit_logs_based_on_model_specs,
751786
):
752787

753788
mocked_open.side_effect = [
@@ -776,6 +811,7 @@ def test_jumpstart_local_metadata_override_specs(
776811
mocked_get_json_file_and_etag_from_s3.assert_not_called()
777812

778813

814+
@patch("sagemaker.jumpstart.cache.utils.emit_logs_based_on_model_specs")
779815
@patch.object(JumpStartModelsCache, "_get_json_file_and_etag_from_s3")
780816
@patch("sagemaker.jumpstart.utils.get_sagemaker_version", lambda: "2.68.3")
781817
@patch.dict(
@@ -791,6 +827,7 @@ def test_jumpstart_local_metadata_override_specs_not_exist_both_directories(
791827
mocked_open: Mock,
792828
mocked_is_dir: Mock,
793829
mocked_get_json_file_and_etag_from_s3: Mock,
830+
mocked_emit_logs_based_on_model_specs,
794831
):
795832
model_id, version = "tensorflow-ic-imagenet-inception-v3-classification-4", "2.0.0"
796833

0 commit comments

Comments
 (0)