Skip to content

Commit 4e768c2

Browse files
committed
change: Replaced generic ValueError with custom subclass when reporting unexpected resource status
1 parent d7d33c4 commit 4e768c2

File tree

3 files changed

+155
-14
lines changed

3 files changed

+155
-14
lines changed

src/sagemaker/exceptions.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Custom exception classes for Sagemaker SDK"""
14+
from __future__ import absolute_import
15+
16+
17+
class UnexpectedStatusException(ValueError):
18+
"""Raised when resource status is not expected and thus not allowed for further execution"""
19+
20+
def __init__(self, message, allowed_statuses, actual_status):
21+
self.allowed_statuses = allowed_statuses
22+
self.actual_status = actual_status
23+
super(UnexpectedStatusException, self).__init__(message)

src/sagemaker/session.py

Lines changed: 27 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
secondary_training_status_changed,
3636
secondary_training_status_message,
3737
)
38+
from sagemaker import exceptions
3839

3940
LOGGER = logging.getLogger("sagemaker")
4041

@@ -826,10 +827,12 @@ def wait_for_model_package(self, model_package_name, poll=5):
826827

827828
if status != "Completed":
828829
reason = desc.get("FailureReason", None)
829-
raise ValueError(
830-
"Error creating model package {}: {} Reason: {}".format(
831-
model_package_name, status, reason
832-
)
830+
raise exceptions.UnexpectedStatusException(
831+
message="Error creating model package {package}: {status} Reason: {reason}".format(
832+
package=model_package_name, status=status, reason=reason
833+
),
834+
allowed_statuses=["Completed"],
835+
actual_status=status,
833836
)
834837
return desc
835838

@@ -990,7 +993,7 @@ def wait_for_job(self, job, poll=5):
990993
(dict): Return value from the ``DescribeTrainingJob`` API.
991994
992995
Raises:
993-
ValueError: If the training job fails.
996+
exceptions.UnexpectedStatusException: If the training job fails.
994997
"""
995998
desc = _wait_until_training_done(
996999
lambda last_desc: _train_done(self.sagemaker_client, job, last_desc), None, poll
@@ -1009,7 +1012,7 @@ def wait_for_compilation_job(self, job, poll=5):
10091012
(dict): Return value from the ``DescribeCompilationJob`` API.
10101013
10111014
Raises:
1012-
ValueError: If the compilation job fails.
1015+
exceptions.UnexpectedStatusException: If the compilation job fails.
10131016
"""
10141017
desc = _wait_until(lambda: _compilation_job_status(self.sagemaker_client, job), poll)
10151018
self._check_job_status(job, desc, "CompilationJobStatus")
@@ -1026,7 +1029,7 @@ def wait_for_tuning_job(self, job, poll=5):
10261029
(dict): Return value from the ``DescribeHyperParameterTuningJob`` API.
10271030
10281031
Raises:
1029-
ValueError: If the hyperparameter tuning job fails.
1032+
exceptions.UnexpectedStatusException: If the hyperparameter tuning job fails.
10301033
"""
10311034
desc = _wait_until(lambda: _tuning_job_status(self.sagemaker_client, job), poll)
10321035
self._check_job_status(job, desc, "HyperParameterTuningJobStatus")
@@ -1043,23 +1046,23 @@ def wait_for_transform_job(self, job, poll=5):
10431046
(dict): Return value from the ``DescribeTransformJob`` API.
10441047
10451048
Raises:
1046-
ValueError: If the transform job fails.
1049+
exceptions.UnexpectedStatusException: If the transform job fails.
10471050
"""
10481051
desc = _wait_until(lambda: _transform_job_status(self.sagemaker_client, job), poll)
10491052
self._check_job_status(job, desc, "TransformJobStatus")
10501053
return desc
10511054

10521055
def _check_job_status(self, job, desc, status_key_name):
10531056
"""Check to see if the job completed successfully and, if not, construct and
1054-
raise a ValueError.
1057+
raise a exceptions.UnexpectedStatusException.
10551058
10561059
Args:
10571060
job (str): The name of the job to check.
10581061
desc (dict[str, str]): The result of ``describe_training_job()``.
10591062
status_key_name (str): Status key name to check for.
10601063
10611064
Raises:
1062-
ValueError: If the training job fails.
1065+
exceptions.UnexpectedStatusException: If the training job fails.
10631066
"""
10641067
status = desc[status_key_name]
10651068
# If the status is capital case, then convert it to Camel case
@@ -1068,7 +1071,13 @@ def _check_job_status(self, job, desc, status_key_name):
10681071
if status not in ("Completed", "Stopped"):
10691072
reason = desc.get("FailureReason", "(No reason provided)")
10701073
job_type = status_key_name.replace("JobStatus", " job")
1071-
raise ValueError("Error for {} {}: {} Reason: {}".format(job_type, job, status, reason))
1074+
raise exceptions.UnexpectedStatusException(
1075+
message="Error for {job_type} {job_name}: {status}. Reason: {reason}".format(
1076+
job_type=job_type, job_name=job, status=status, reason=reason
1077+
),
1078+
allowed_statuses=["Completed", "Stopped"],
1079+
actual_status=status,
1080+
)
10721081

10731082
def wait_for_endpoint(self, endpoint, poll=5):
10741083
"""Wait for an Amazon SageMaker endpoint deployment to complete.
@@ -1085,8 +1094,12 @@ def wait_for_endpoint(self, endpoint, poll=5):
10851094

10861095
if status != "InService":
10871096
reason = desc.get("FailureReason", None)
1088-
raise ValueError(
1089-
"Error hosting endpoint {}: {} Reason: {}".format(endpoint, status, reason)
1097+
raise exceptions.UnexpectedStatusException(
1098+
message="Error hosting endpoint {endpoint}: {status}. Reason: {reason}.".format(
1099+
endpoint=endpoint, status=status, reason=reason
1100+
),
1101+
allowed_statuses=["InService"],
1102+
actual_status=status,
10901103
)
10911104
return desc
10921105

@@ -1334,7 +1347,7 @@ def logs_for_job( # noqa: C901 - suppress complexity warning for this method
13341347
completion (default: 5).
13351348
13361349
Raises:
1337-
ValueError: If waiting and the training job fails.
1350+
exceptions.UnexpectedStatusException: If waiting and the training job fails.
13381351
"""
13391352

13401353
description = self.sagemaker_client.describe_training_job(TrainingJobName=job_name)
Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
import pytest
16+
from mock import Mock, MagicMock
17+
import sagemaker
18+
19+
EXPANDED_ROLE = "arn:aws:iam::111111111111:role/ExpandedRole"
20+
REGION = "us-west-2"
21+
MODEL_PACKAGE_NAME = "my_model_package"
22+
JOB_NAME = "my_job_name"
23+
ENDPOINT_NAME = "the_point_of_end"
24+
25+
26+
def get_sagemaker_session(returns_status):
27+
boto_mock = Mock(name="boto_session", region_name=REGION)
28+
client_mock = Mock()
29+
client_mock.describe_model_package = MagicMock(
30+
return_value={"ModelPackageStatus": returns_status}
31+
)
32+
client_mock.describe_endpoint = MagicMock(return_value={"EndpointStatus": returns_status})
33+
ims = sagemaker.Session(boto_session=boto_mock, sagemaker_client=client_mock)
34+
ims.expand_role = Mock(return_value=EXPANDED_ROLE)
35+
return ims
36+
37+
38+
def test_does_not_raise_when_successfully_created_package():
39+
try:
40+
sagemaker_session = get_sagemaker_session(returns_status="Completed")
41+
sagemaker_session.wait_for_model_package(MODEL_PACKAGE_NAME)
42+
except sagemaker.exceptions.UnexpectedStatusException:
43+
pytest.fail("UnexpectedStatusException was thrown while it should not")
44+
45+
46+
def test_raise_when_failed_created_package():
47+
try:
48+
sagemaker_session = get_sagemaker_session(returns_status="EnRoute")
49+
sagemaker_session.wait_for_model_package(MODEL_PACKAGE_NAME)
50+
assert (
51+
False
52+
), "sagemaker.exceptions.UnexpectedStatusException should have been raised but was not"
53+
except Exception as e:
54+
assert type(e) == sagemaker.exceptions.UnexpectedStatusException
55+
assert e.actual_status == "EnRoute"
56+
assert "Completed" in e.allowed_statuses
57+
58+
59+
def test_does_not_raise_when_correct_job_status():
60+
try:
61+
job = Mock()
62+
sagemaker_session = get_sagemaker_session(returns_status="Stopped")
63+
sagemaker_session._check_job_status(
64+
job, {"TransformationJobStatus": "Stopped"}, "TransformationJobStatus"
65+
)
66+
except sagemaker.exceptions.UnexpectedStatusException:
67+
pytest.fail("UnexpectedStatusException was thrown while it should not")
68+
69+
70+
def test_does_raise_when_incorrect_job_status():
71+
try:
72+
job = Mock()
73+
sagemaker_session = get_sagemaker_session(returns_status="Failed")
74+
sagemaker_session._check_job_status(
75+
job, {"TransformationJobStatus": "Failed"}, "TransformationJobStatus"
76+
)
77+
assert (
78+
False
79+
), "sagemaker.exceptions.UnexpectedStatusException should have been raised but was not"
80+
except Exception as e:
81+
assert type(e) == sagemaker.exceptions.UnexpectedStatusException
82+
assert e.actual_status == "Failed"
83+
assert "Completed" in e.allowed_statuses
84+
assert "Stopped" in e.allowed_statuses
85+
86+
87+
def test_does_not_raise_when_successfully_deployed_endpoint():
88+
try:
89+
sagemaker_session = get_sagemaker_session(returns_status="InService")
90+
sagemaker_session.wait_for_endpoint(ENDPOINT_NAME)
91+
except sagemaker.exceptions.UnexpectedStatusException:
92+
pytest.fail("UnexpectedStatusException was thrown while it should not")
93+
94+
95+
def test_raise_when_failed_to_deploy_endpoint():
96+
try:
97+
sagemaker_session = get_sagemaker_session(returns_status="Failed")
98+
assert sagemaker_session.wait_for_endpoint(ENDPOINT_NAME)
99+
assert (
100+
False
101+
), "sagemaker.exceptions.UnexpectedStatusException should have been raised but was not"
102+
except Exception as e:
103+
assert type(e) == sagemaker.exceptions.UnexpectedStatusException
104+
assert e.actual_status == "Failed"
105+
assert "InService" in e.allowed_statuses

0 commit comments

Comments
 (0)