Skip to content

Commit 017160f

Browse files
authored
chore: excessive jumpstart bucket logging (#4053)
1 parent e0102e7 commit 017160f

File tree

3 files changed

+40
-15
lines changed

3 files changed

+40
-15
lines changed

src/sagemaker/jumpstart/accessors.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,20 @@ class JumpStartModelsAccessor(object):
4242
_cache: Optional[cache.JumpStartModelsCache] = None
4343
_curr_region = JUMPSTART_DEFAULT_REGION_NAME
4444

45+
_content_bucket: Optional[str] = None
46+
4547
_cache_kwargs: Dict[str, Any] = {}
4648

49+
@staticmethod
50+
def set_jumpstart_content_bucket(content_bucket: str) -> None:
51+
"""Sets JumpStart content bucket."""
52+
JumpStartModelsAccessor._content_bucket = content_bucket
53+
54+
@staticmethod
55+
def get_jumpstart_content_bucket() -> Optional[str]:
56+
"""Returns JumpStart content bucket."""
57+
return JumpStartModelsAccessor._content_bucket
58+
4759
@staticmethod
4860
def _validate_and_mutate_region_cache_kwargs(
4961
cache_kwargs: Optional[Dict[str, Any]] = None, region: Optional[str] = None

src/sagemaker/jumpstart/utils.py

Lines changed: 27 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -72,21 +72,37 @@ def get_jumpstart_content_bucket(
7272
RuntimeError: If JumpStart is not launched in ``region``.
7373
"""
7474

75+
old_content_bucket: Optional[
76+
str
77+
] = accessors.JumpStartModelsAccessor.get_jumpstart_content_bucket()
78+
79+
info_logs: List[str] = []
80+
81+
bucket_to_return: Optional[str] = None
7582
if (
7683
constants.ENV_VARIABLE_JUMPSTART_CONTENT_BUCKET_OVERRIDE in os.environ
7784
and len(os.environ[constants.ENV_VARIABLE_JUMPSTART_CONTENT_BUCKET_OVERRIDE]) > 0
7885
):
79-
bucket_override = os.environ[constants.ENV_VARIABLE_JUMPSTART_CONTENT_BUCKET_OVERRIDE]
80-
constants.JUMPSTART_LOGGER.info("Using JumpStart bucket override: '%s'", bucket_override)
81-
return bucket_override
82-
try:
83-
return constants.JUMPSTART_REGION_NAME_TO_LAUNCHED_REGION_DICT[region].content_bucket
84-
except KeyError:
85-
formatted_launched_regions_str = get_jumpstart_launched_regions_message()
86-
raise ValueError(
87-
f"Unable to get content bucket for JumpStart in {region} region. "
88-
f"{formatted_launched_regions_str}"
89-
)
86+
bucket_to_return = os.environ[constants.ENV_VARIABLE_JUMPSTART_CONTENT_BUCKET_OVERRIDE]
87+
info_logs.append(f"Using JumpStart bucket override: '{bucket_to_return}'")
88+
else:
89+
try:
90+
bucket_to_return = constants.JUMPSTART_REGION_NAME_TO_LAUNCHED_REGION_DICT[
91+
region
92+
].content_bucket
93+
except KeyError:
94+
formatted_launched_regions_str = get_jumpstart_launched_regions_message()
95+
raise ValueError(
96+
f"Unable to get content bucket for JumpStart in {region} region. "
97+
f"{formatted_launched_regions_str}"
98+
)
99+
100+
accessors.JumpStartModelsAccessor.set_jumpstart_content_bucket(bucket_to_return)
101+
102+
if bucket_to_return != old_content_bucket:
103+
for info_log in info_logs:
104+
constants.JUMPSTART_LOGGER.info(info_log)
105+
return bucket_to_return
90106

91107

92108
def get_formatted_manifest(

tests/unit/sagemaker/jumpstart/test_utils.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -57,10 +57,7 @@ def test_get_jumpstart_content_bucket_override():
5757
with patch("logging.Logger.info") as mocked_info_log:
5858
random_region = "random_region"
5959
assert "some-val" == utils.get_jumpstart_content_bucket(random_region)
60-
mocked_info_log.assert_called_once_with(
61-
"Using JumpStart bucket override: '%s'",
62-
"some-val",
63-
)
60+
mocked_info_log.assert_called_once_with("Using JumpStart bucket override: 'some-val'")
6461

6562

6663
def test_get_jumpstart_launched_regions_message():

0 commit comments

Comments
 (0)