Skip to content

Commit ccae97b

Browse files
authored
fix: minor jumpstart logging improvements (#1366)
1 parent 06a8557 commit ccae97b

File tree

10 files changed

+30
-40
lines changed

10 files changed

+30
-40
lines changed

src/sagemaker/jumpstart/exceptions.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,8 @@ def get_wildcard_model_version_msg(
4242

4343
return (
4444
f"Using model '{model_id}' with wildcard version identifier '{wildcard_model_version}'. "
45-
f"Please consider pinning to version '{full_model_version}' to "
46-
f"ensure stable results. {_MAJOR_VERSION_WARNING_MSG}"
45+
f"You can pin to version '{full_model_version}' "
46+
f"for more stable results. {_MAJOR_VERSION_WARNING_MSG}"
4747
)
4848

4949

@@ -53,9 +53,9 @@ def get_old_model_version_msg(
5353
"""Returns customer-facing message associated with using an old model version."""
5454

5555
return (
56-
f"Using model '{model_id}' with old version '{current_model_version}'. "
57-
f"Please consider upgrading to version '{latest_model_version}'"
58-
f". {_MAJOR_VERSION_WARNING_MSG}"
56+
f"Using model '{model_id}' with version '{current_model_version}'. "
57+
f"You can upgrade to version '{latest_model_version}' to get the latest model "
58+
f"specifications. {_MAJOR_VERSION_WARNING_MSG}"
5959
)
6060

6161

src/sagemaker/jumpstart/factory/estimator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -586,7 +586,7 @@ def _add_env_to_kwargs(
586586
)
587587

588588
for key, value in extra_env_vars.items():
589-
update_dict_if_key_not_present(
589+
kwargs.environment = update_dict_if_key_not_present(
590590
kwargs.environment,
591591
key,
592592
value,

src/sagemaker/jumpstart/utils.py

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -755,13 +755,5 @@ def is_valid_model_id(
755755
if script == enums.JumpStartScriptScope.INFERENCE:
756756
return model_id in model_id_set
757757
if script == enums.JumpStartScriptScope.TRAINING:
758-
return (
759-
model_id in model_id_set
760-
and accessors.JumpStartModelsAccessor.get_model_specs(
761-
region=region,
762-
model_id=model_id,
763-
version=model_version,
764-
s3_client=s3_client,
765-
).training_supported
766-
)
758+
return model_id in model_id_set
767759
raise ValueError(f"Unsupported script: {script}")

tests/integ/sagemaker/jumpstart/estimator/test_jumpstart_estimator.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from __future__ import absolute_import
1414
import os
1515
import time
16+
import mock
1617

1718
import pytest
1819
from sagemaker.jumpstart.constants import JUMPSTART_DEFAULT_REGION_NAME
@@ -127,7 +128,8 @@ def test_gated_model_training(setup):
127128
assert response is not None
128129

129130

130-
def test_instatiating_estimator_not_too_slow(setup):
131+
@mock.patch("sagemaker.jumpstart.cache.JUMPSTART_LOGGER.warning")
132+
def test_instatiating_estimator(mock_warning_logger, setup):
131133

132134
model_id = "xgboost-classification-model"
133135

@@ -142,3 +144,5 @@ def test_instatiating_estimator_not_too_slow(setup):
142144
elapsed_time = time.perf_counter() - start_time
143145

144146
assert elapsed_time <= MAX_INIT_TIME_SECONDS
147+
148+
mock_warning_logger.assert_called_once()

tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from __future__ import absolute_import
1414
import os
1515
import time
16+
from unittest import mock
1617

1718
import pytest
1819

@@ -114,7 +115,8 @@ def test_model_package_arn_jumpstart_model(setup):
114115
assert response is not None
115116

116117

117-
def test_instatiating_model_not_too_slow(setup):
118+
@mock.patch("sagemaker.jumpstart.cache.JUMPSTART_LOGGER.warning")
119+
def test_instatiating_model(mock_warning_logger, setup):
118120

119121
model_id = "catboost-regression-model"
120122

@@ -130,6 +132,8 @@ def test_instatiating_model_not_too_slow(setup):
130132

131133
assert elapsed_time <= MAX_INIT_TIME_SECONDS
132134

135+
mock_warning_logger.assert_called_once()
136+
133137

134138
def test_jumpstart_model_register(setup):
135139
model_id = "huggingface-txt2img-conflictx-complex-lineart"

tests/unit/sagemaker/jumpstart/constants.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2777,7 +2777,8 @@
27772777
"g5": {
27782778
"regional_properties": {"image_uri": "$gpu_ecr_uri_1"},
27792779
"properties": {
2780-
"gated_model_key_env_var_value": "meta-training/train-meta-textgeneration-llama-2-7b.tar.gz"
2780+
"gated_model_key_env_var_value": "meta-training/train-meta-textgeneration-llama-2-7b.tar.gz",
2781+
"environment_variables": {"SELF_DESTRUCT": "true"},
27812782
},
27822783
},
27832784
"local_gpu": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}},

tests/unit/sagemaker/jumpstart/estimator/test_estimator.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -495,6 +495,7 @@ def test_gated_model_non_model_package_s3_uri(
495495
encrypt_inter_container_traffic=True,
496496
enable_network_isolation=True,
497497
environment={
498+
"SELF_DESTRUCT": "true",
498499
"accept_eula": True,
499500
"SageMakerGatedModelS3Uri": "s3://top-secret-private-"
500501
"models-bucket/meta-training/train-meta-textgeneration-llama-2-7b.tar.gz",

tests/unit/sagemaker/jumpstart/test_cache.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -587,8 +587,8 @@ def test_jumpstart_cache_makes_correct_s3_calls(
587587
)
588588
mocked_warning_log.assert_called_once_with(
589589
"Using model 'pytorch-ic-imagenet-inception-v3-classification-4' with wildcard "
590-
"version identifier '*'. Please consider pinning to version '2.0.0' to "
591-
"ensure stable results. Note that models may have different input/output "
590+
"version identifier '*'. You can pin to version '2.0.0' for more "
591+
"stable results. Note that models may have different input/output "
592592
"signatures after a major version upgrade."
593593
)
594594
mocked_warning_log.reset_mock()

tests/unit/sagemaker/jumpstart/test_exceptions.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
def test_get_wildcard_model_version_msg():
2222
assert (
2323
"Using model 'mother_of_all_models' with wildcard version identifier '*'. "
24-
"Please consider pinning to version '1.2.3' to ensure stable results. "
24+
"You can pin to version '1.2.3' for more stable results. "
2525
"Note that models may have different input/output signatures after a "
2626
"major version upgrade."
2727
== get_wildcard_model_version_msg("mother_of_all_models", "*", "1.2.3")
@@ -30,8 +30,8 @@ def test_get_wildcard_model_version_msg():
3030

3131
def test_get_old_model_version_msg():
3232
assert (
33-
"Using model 'mother_of_all_models' with old version '1.0.0'. "
34-
"Please consider upgrading to version '1.2.3'. Note that models "
35-
"may have different input/output signatures after a major "
33+
"Using model 'mother_of_all_models' with version '1.0.0'. "
34+
"You can upgrade to version '1.2.3' to get the latest model specifications. "
35+
"Note that models may have different input/output signatures after a major "
3636
"version upgrade." == get_old_model_version_msg("mother_of_all_models", "1.0.0", "1.2.3")
3737
)

tests/unit/sagemaker/jumpstart/test_utils.py

Lines changed: 4 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -967,9 +967,9 @@ def test_jumpstart_old_model_spec(mock_get_manifest):
967967
)
968968

969969
mocked_info_log.assert_called_once_with(
970-
"Using model 'tensorflow-ic-imagenet-inception-v3-classification-4' with old version '1.0.0'. "
971-
"Please consider upgrading to version '1.1.0'. Note that models may have different "
972-
"input/output signatures after a major version upgrade."
970+
"Using model 'tensorflow-ic-imagenet-inception-v3-classification-4' with version '1.0.0'. "
971+
"You can upgrade to version '1.1.0' to get the latest model specifications. Note that models "
972+
"may have different input/output signatures after a major version upgrade."
973973
)
974974

975975
mocked_info_log.reset_mock()
@@ -1191,12 +1191,6 @@ def test_is_valid_model_id_true(
11911191
mock_get_manifest.assert_called_once_with(
11921192
region=JUMPSTART_DEFAULT_REGION_NAME, s3_client=mock_s3_client_value
11931193
)
1194-
mock_get_model_specs.assert_called_once_with(
1195-
region=JUMPSTART_DEFAULT_REGION_NAME,
1196-
model_id="bee",
1197-
version="*",
1198-
s3_client=mock_s3_client_value,
1199-
)
12001194

12011195
@patch("sagemaker.jumpstart.utils.accessors.JumpStartModelsAccessor._get_manifest")
12021196
@patch("sagemaker.jumpstart.utils.accessors.JumpStartModelsAccessor.get_model_specs")
@@ -1254,13 +1248,7 @@ def test_is_valid_model_id_false(self, mock_get_model_specs: Mock, mock_get_mani
12541248
mock_get_model_specs.reset_mock()
12551249

12561250
mock_get_model_specs.return_value = Mock(training_supported=False)
1257-
self.assertFalse(utils.is_valid_model_id("ay", script=JumpStartScriptScope.TRAINING))
1251+
self.assertTrue(utils.is_valid_model_id("ay", script=JumpStartScriptScope.TRAINING))
12581252
mock_get_manifest.assert_called_once_with(
12591253
region=JUMPSTART_DEFAULT_REGION_NAME, s3_client=mock_s3_client_value
12601254
)
1261-
mock_get_model_specs.assert_called_once_with(
1262-
region=JUMPSTART_DEFAULT_REGION_NAME,
1263-
model_id="ay",
1264-
version="*",
1265-
s3_client=mock_s3_client_value,
1266-
)

0 commit comments

Comments
 (0)