Skip to content

Commit 6b783c3

Browse files
committed
fix: gated models unsupported region
1 parent 63f39e1 commit 6b783c3

File tree

3 files changed

+83
-0
lines changed

3 files changed

+83
-0
lines changed

src/sagemaker/jumpstart/artifacts/model_packages.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,13 @@ def _retrieve_model_package_arn(
7272

7373
regional_arn = model_specs.hosting_model_package_arns.get(region)
7474

75+
if regional_arn is None:
76+
raise ValueError(
77+
f"Model package arn for '{model_id}' not supported in {region}. "
78+
"Please try one of the following regions: "
79+
f"{', '.join(model_specs.hosting_model_package_arns.keys())}."
80+
)
81+
7582
return regional_arn
7683

7784
raise NotImplementedError(f"Model Package ARN not supported for scope: '{scope}'")
@@ -130,6 +137,13 @@ def _retrieve_model_package_model_artifact_s3_uri(
130137

131138
model_s3_uri = model_specs.training_model_package_artifact_uris.get(region)
132139

140+
if model_s3_uri is None:
141+
raise ValueError(
142+
f"Model package artifact s3 uri for '{model_id}' not supported in {region}. "
143+
"Please try one of the following regions: "
144+
f"{', '.join(model_specs.training_model_package_artifact_uris.keys())}."
145+
)
146+
133147
return model_s3_uri
134148

135149
raise NotImplementedError(f"Model Package Artifact URI not supported for scope: '{scope}'")

tests/unit/sagemaker/jumpstart/estimator/test_estimator.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -353,6 +353,49 @@ def test_gated_model_s3_uri(
353353
use_compiled_model=False,
354354
)
355355

356+
@mock.patch("sagemaker.utils.sagemaker_timestamp")
357+
@mock.patch("sagemaker.jumpstart.estimator.is_valid_model_id")
358+
@mock.patch("sagemaker.jumpstart.factory.model.Session")
359+
@mock.patch("sagemaker.jumpstart.factory.estimator.Session")
360+
@mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
361+
@mock.patch("sagemaker.jumpstart.estimator.Estimator.__init__")
362+
@mock.patch("sagemaker.jumpstart.estimator.Estimator.fit")
363+
@mock.patch("sagemaker.jumpstart.estimator.Estimator.deploy")
364+
@mock.patch("sagemaker.jumpstart.factory.estimator.JUMPSTART_DEFAULT_REGION_NAME", region)
365+
@mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region)
366+
def test_jumpstart_model_package_artifact_s3_uri_unsupported_region(
367+
self,
368+
mock_estimator_deploy: mock.Mock,
369+
mock_estimator_fit: mock.Mock,
370+
mock_estimator_init: mock.Mock,
371+
mock_get_model_specs: mock.Mock,
372+
mock_session_estimator: mock.Mock,
373+
mock_session_model: mock.Mock,
374+
mock_is_valid_model_id: mock.Mock,
375+
mock_timestamp: mock.Mock,
376+
):
377+
mock_estimator_deploy.return_value = default_predictor
378+
379+
mock_timestamp.return_value = "8675309"
380+
381+
mock_is_valid_model_id.return_value = True
382+
383+
model_id, _ = "js-gated-artifact-trainable-model", "*"
384+
385+
mock_get_model_specs.side_effect = get_special_model_spec
386+
387+
mock_session_estimator.return_value = sagemaker_session
388+
mock_session_model.return_value = sagemaker_session
389+
390+
with pytest.raises(ValueError) as e:
391+
JumpStartEstimator(model_id=model_id, region="eu-north-1")
392+
393+
assert (
394+
str(e.value) == "Model package artifact s3 uri for 'js-gated-artifact-trainable-model' "
395+
"not supported in eu-north-1. Please try one of the following regions: "
396+
"us-west-2, us-east-1, eu-west-1, ap-southeast-1."
397+
)
398+
356399
@mock.patch("sagemaker.jumpstart.estimator.is_valid_model_id")
357400
@mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
358401
@mock.patch("sagemaker.jumpstart.estimator.Estimator.__init__")

tests/unit/sagemaker/jumpstart/model/test_model.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -646,6 +646,32 @@ def test_jumpstart_model_package_arn_override(
646646
},
647647
)
648648

649+
@mock.patch("sagemaker.jumpstart.model.is_valid_model_id")
650+
@mock.patch("sagemaker.jumpstart.factory.model.Session")
651+
@mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
652+
@mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region)
653+
def test_jumpstart_model_package_arn_unsupported_region(
654+
self,
655+
mock_get_model_specs: mock.Mock,
656+
mock_session: mock.Mock,
657+
mock_is_valid_model_id: mock.Mock,
658+
):
659+
660+
mock_is_valid_model_id.return_value = True
661+
662+
model_id, _ = "js-model-package-arn", "*"
663+
664+
mock_get_model_specs.side_effect = get_special_model_spec
665+
666+
mock_session.return_value = MagicMock(sagemaker_config={})
667+
668+
with pytest.raises(ValueError) as e:
669+
JumpStartModel(model_id=model_id, region="us-east-2")
670+
assert (
671+
str(e.value) == "Model package arn for 'js-model-package-arn' not supported in "
672+
"us-east-2. Please try one of the following regions: us-west-2, us-east-1."
673+
)
674+
649675

650676
def test_jumpstart_model_requires_model_id():
651677
with pytest.raises(ValueError):

0 commit comments

Comments
 (0)