Skip to content

change: improve logging and exception messages #4877

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/sagemaker/jumpstart/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,8 @@ def _model_id_retrieval_function(
raise KeyError(error_msg)

error_msg = f"Unable to find model manifest for '{model_id}' with version '{version}'. "
error_msg += f"Visit {MODEL_ID_LIST_WEB_URL} for updated list of models. "
error_msg += "Specify a different model ID or try a different AWS Region. "
error_msg += f"For a list of available models, see {MODEL_ID_LIST_WEB_URL}. "

other_model_id_version = None
if model_type == JumpStartModelType.OPEN_WEIGHTS:
Expand Down
4 changes: 2 additions & 2 deletions src/sagemaker/jumpstart/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@
)

INVALID_MODEL_ID_ERROR_MSG = (
"Invalid model ID: '{model_id}'. Please visit "
f"{MODEL_ID_LIST_WEB_URL} for a list of valid model IDs. "
"Invalid model ID: '{model_id}'. Specify a different model ID or try a different AWS Region. "
f"For a list of available models, see {MODEL_ID_LIST_WEB_URL}. "
"The module `sagemaker.jumpstart.notebook_utils` contains utilities for "
"fetching model IDs. We recommend upgrading to the latest version of sagemaker "
"to get access to the most models."
Expand Down
9 changes: 7 additions & 2 deletions src/sagemaker/jumpstart/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,12 @@
)
from sagemaker.session import Session
from sagemaker.config import load_sagemaker_config
from sagemaker.utils import resolve_value_from_config, TagsDict, get_instance_rate_per_hour
from sagemaker.utils import (
resolve_value_from_config,
TagsDict,
get_instance_rate_per_hour,
get_domain_for_region,
)
from sagemaker.workflow import is_pipeline_variable
from sagemaker.user_agent import get_user_agent_extra_suffix

Expand Down Expand Up @@ -553,7 +558,7 @@ def get_eula_message(model_specs: JumpStartModelSpecs, region: str) -> str:
return (
f"Model '{model_specs.model_id}' requires accepting end-user license agreement (EULA). "
f"See https://{get_jumpstart_content_bucket(region=region)}.s3.{region}."
f"amazonaws.com{'.cn' if region.startswith('cn-') else ''}"
f"{get_domain_for_region(region)}"
f"/{model_specs.hosting_eula_key} for terms of use."
)

Expand Down
18 changes: 18 additions & 0 deletions src/sagemaker/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,15 @@
from sagemaker.workflow import is_pipeline_variable, is_pipeline_parameter_string
from sagemaker.workflow.entities import PipelineVariable

ALTERNATE_DOMAINS = {
"cn-north-1": "amazonaws.com.cn",
"cn-northwest-1": "amazonaws.com.cn",
"us-iso-east-1": "c2s.ic.gov",
"us-isob-east-1": "sc2s.sgov.gov",
"us-isof-south-1": "csp.hci.ic.gov",
"us-isof-east-1": "csp.hci.ic.gov",
}

ECR_URI_PATTERN = r"^(\d+)(\.)dkr(\.)ecr(\.)(.+)(\.)(.*)(/)(.*:.*)$"
MODEL_PACKAGE_ARN_PATTERN = (
r"arn:aws([a-z\-]*)?:sagemaker:([a-z0-9\-]*):([0-9]{12}):model-package/(.*)"
Expand Down Expand Up @@ -1905,3 +1914,12 @@ def remove_tag_with_key(key: str, tags: Optional[Tags]) -> Optional[Tags]:
if len(updated_tags) == 1:
return updated_tags[0]
return updated_tags


def get_domain_for_region(region: str) -> str:
"""Returns the domain for the given region.

Args:
region (str): AWS region name.
"""
return ALTERNATE_DOMAINS.get(region, "amazonaws.com")
10 changes: 2 additions & 8 deletions tests/unit/sagemaker/image_uris/expected_uris.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,8 @@
# language governing permissions and limitations under the License.
from __future__ import absolute_import

ALTERNATE_DOMAINS = {
"cn-north-1": "amazonaws.com.cn",
"cn-northwest-1": "amazonaws.com.cn",
"us-iso-east-1": "c2s.ic.gov",
"us-isob-east-1": "sc2s.sgov.gov",
"us-isof-south-1": "csp.hci.ic.gov",
"us-isof-east-1": "csp.hci.ic.gov",
}
from sagemaker.utils import ALTERNATE_DOMAINS

DOMAIN = "amazonaws.com"
IMAGE_URI_FORMAT = "{}.dkr.ecr.{}.{}/{}:{}"
MONITOR_URI_FORMAT = "{}.dkr.ecr.{}.{}/sagemaker-model-monitor-analyzer"
Expand Down
22 changes: 14 additions & 8 deletions tests/unit/sagemaker/jumpstart/test_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,26 +205,31 @@ def test_jumpstart_cache_get_header():
)
assert (
"Unable to find model manifest for 'pytorch-ic-imagenet-inception-v3-classification-4' with "
"version '3.*'. Visit https://sagemaker.readthedocs.io/en/stable/doc_utils/pretrainedmodels.html "
"for updated list of models. Consider using model ID 'pytorch-ic-imagenet-inception-v3-"
"version '3.*'. Specify a different model ID or try a different AWS Region. "
"For a list of available models, see "
"https://sagemaker.readthedocs.io/en/stable/doc_utils/pretrainedmodels.html. "
"Consider using model ID "
"'pytorch-ic-imagenet-inception-v3-"
"classification-4' with version '2.0.0'."
) in str(e.value)

with pytest.raises(KeyError) as e:
cache.get_header(model_id="pytorch-ic-", semantic_version_str="*")
assert (
"Unable to find model manifest for 'pytorch-ic-' with version '*'. "
"Visit https://sagemaker.readthedocs.io/en/stable/doc_utils/pretrainedmodels.html "
"for updated list of models. "
"Specify a different model ID or try a different AWS Region. "
"For a list of available models, see "
"https://sagemaker.readthedocs.io/en/stable/doc_utils/pretrainedmodels.html. "
"Did you mean to use model ID 'pytorch-ic-imagenet-inception-v3-classification-4'?"
) in str(e.value)

with pytest.raises(KeyError) as e:
cache.get_header(model_id="tensorflow-ic-", semantic_version_str="*")
assert (
"Unable to find model manifest for 'tensorflow-ic-' with version '*'. "
"Visit https://sagemaker.readthedocs.io/en/stable/doc_utils/pretrainedmodels.html "
"for updated list of models. "
"Specify a different model ID or try a different AWS Region. For a list "
"of available models, see "
"https://sagemaker.readthedocs.io/en/stable/doc_utils/pretrainedmodels.html. "
"Did you mean to use model ID 'tensorflow-ic-imagenet-inception-"
"v3-classification-4'?"
) in str(e.value)
Expand All @@ -237,8 +242,9 @@ def test_jumpstart_cache_get_header():
)
assert (
"Unable to find model manifest for 'ai21-summarize' with version '1.1.003'. "
"Visit https://sagemaker.readthedocs.io/en/stable/doc_utils/pretrainedmodels.html "
"for updated list of models. "
"Specify a different model ID or try a different AWS Region. "
"For a list of available models, see "
"https://sagemaker.readthedocs.io/en/stable/doc_utils/pretrainedmodels.html. "
"Did you mean to use model ID 'ai21-summarization'?"
) in str(e.value)

Expand Down
18 changes: 18 additions & 0 deletions tests/unit/sagemaker/jumpstart/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2150,3 +2150,21 @@ def test_has_instance_rate_stat(stats, expected):
def test_deployment_config_response_data(data, expected):
out = utils.deployment_config_response_data(data)
assert out == expected


class TestGetEulaMessage(TestCase):
mock_model_specs = Mock(model_id="some-model-id", hosting_eula_key="some-eula-key")

def test_get_domain_for_region(self):
self.assertEqual(
utils.get_eula_message(self.mock_model_specs, "us-west-2"),
"Model 'some-model-id' requires accepting end-user license agreement (EULA). See"
" https://jumpstart-cache-prod-us-west-2.s3.us-west-2.amazonaws.com/some-eula-key "
"for terms of use.",
)
self.assertEqual(
utils.get_eula_message(self.mock_model_specs, "cn-north-1"),
"Model 'some-model-id' requires accepting end-user license agreement (EULA). See"
" https://jumpstart-cache-prod-cn-north-1.s3.cn-north-1.amazonaws.com.cn/some-eula-key "
"for terms of use.",
)
13 changes: 13 additions & 0 deletions tests/unit/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
camel_case_to_pascal_case,
deep_override_dict,
flatten_dict,
get_domain_for_region,
get_instance_type_family,
retry_with_backoff,
check_and_get_run_experiment_config,
Expand Down Expand Up @@ -2231,3 +2232,15 @@ def test_remove_non_existent_tag(self):
def test_remove_only_tag(self):
original_tags = [{"Key": "Tag1", "Value": "Value1"}]
self.assertIsNone(remove_tag_with_key("Tag1", original_tags))


class TestGetDomainForRegion(TestCase):
def test_get_domain_for_region(self):
self.assertEqual(get_domain_for_region("us-west-2"), "amazonaws.com")
self.assertEqual(get_domain_for_region("eu-west-1"), "amazonaws.com")
self.assertEqual(get_domain_for_region("ap-northeast-1"), "amazonaws.com")
self.assertEqual(get_domain_for_region("us-gov-west-1"), "amazonaws.com")
self.assertEqual(get_domain_for_region("cn-northwest-1"), "amazonaws.com.cn")
self.assertEqual(get_domain_for_region("us-iso-east-1"), "c2s.ic.gov")
self.assertEqual(get_domain_for_region("us-isob-east-1"), "sc2s.sgov.gov")
self.assertEqual(get_domain_for_region("invalid-region"), "amazonaws.com")