Skip to content

Commit 01f7f4c

Browse files
committed
change: Replaced generic ValueError with custom subclass when reporting unexpected resource status
1 parent 2d7bff8 commit 01f7f4c

File tree

3 files changed

+161
-38
lines changed

3 files changed

+161
-38
lines changed

src/sagemaker/exceptions.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
16+
class UnexpectedStatusException(ValueError):
17+
"""Raised when resource status is not expected and thus not allowed for further execution"""
18+
def __init__(self, message, allowed_statuses, actual_status):
19+
self.allowed_statuses = allowed_statuses
20+
self.actual_status = actual_status
21+
super(UnexpectedStatusException, self).__init__(message)

src/sagemaker/session.py

Lines changed: 47 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
secondary_training_status_changed,
3535
secondary_training_status_message,
3636
)
37+
from sagemaker import exceptions
3738

3839
LOGGER = logging.getLogger("sagemaker")
3940

@@ -206,7 +207,7 @@ def default_bucket(self):
206207
Bucket=default_bucket, CreateBucketConfiguration={"LocationConstraint": region}
207208
)
208209

209-
LOGGER.info("Created S3 bucket: %s", default_bucket)
210+
LOGGER.info("Created S3 bucket: {}".format(default_bucket))
210211
except ClientError as e:
211212
error_code = e.response["Error"]["Code"]
212213
message = e.response["Error"]["Message"]
@@ -343,8 +344,8 @@ def train( # noqa: C901
343344
if encrypt_inter_container_traffic:
344345
train_request["EnableInterContainerTrafficEncryption"] = encrypt_inter_container_traffic
345346

346-
LOGGER.info("Creating training-job with name: %s", job_name)
347-
LOGGER.debug("train request: %s", json.dumps(train_request, indent=4))
347+
LOGGER.info("Creating training-job with name: {}".format(job_name))
348+
LOGGER.debug("train request: {}".format(json.dumps(train_request, indent=4)))
348349
self.sagemaker_client.create_training_job(**train_request)
349350

350351
def compile_model(
@@ -379,7 +380,7 @@ def compile_model(
379380
if tags is not None:
380381
compilation_job_request["Tags"] = tags
381382

382-
LOGGER.info("Creating compilation-job with name: %s", job_name)
383+
LOGGER.info("Creating compilation-job with name: {}".format(job_name))
383384
self.sagemaker_client.create_compilation_job(**compilation_job_request)
384385

385386
def tune(
@@ -521,8 +522,8 @@ def tune(
521522
if encrypt_inter_container_traffic:
522523
tune_request["TrainingJobDefinition"]["EnableInterContainerTrafficEncryption"] = True
523524

524-
LOGGER.info("Creating hyperparameter tuning job with name: %s", job_name)
525-
LOGGER.debug("tune request: %s", json.dumps(tune_request, indent=4))
525+
LOGGER.info("Creating hyperparameter tuning job with name: {}".format(job_name))
526+
LOGGER.debug("tune request: {}".format(json.dumps(tune_request, indent=4)))
526527
self.sagemaker_client.create_hyper_parameter_tuning_job(**tune_request)
527528

528529
def stop_tuning_job(self, name):
@@ -535,17 +536,18 @@ def stop_tuning_job(self, name):
535536
ClientError: If an error occurs while trying to stop the hyperparameter tuning job.
536537
"""
537538
try:
538-
LOGGER.info("Stopping tuning job: %s", name)
539+
LOGGER.info("Stopping tuning job: {}".format(name))
539540
self.sagemaker_client.stop_hyper_parameter_tuning_job(HyperParameterTuningJobName=name)
540541
except ClientError as e:
541542
error_code = e.response["Error"]["Code"]
542543
# allow to pass if the job already stopped
543544
if error_code == "ValidationException":
544-
LOGGER.info("Tuning job: %s is already stopped or not running.", name)
545+
LOGGER.info("Tuning job: {} is already stopped or not running.".format(name))
545546
else:
546547
LOGGER.error(
547-
"Error occurred while attempting to stop tuning job: %s. Please try again.",
548-
name,
548+
"Error occurred while attempting to stop tuning job: {}. Please try again.".format(
549+
name
550+
)
549551
)
550552
raise
551553

@@ -607,8 +609,8 @@ def transform(
607609
if data_processing is not None:
608610
transform_request["DataProcessing"] = data_processing
609611

610-
LOGGER.info("Creating transform job with name: %s", job_name)
611-
LOGGER.debug("Transform request: %s", json.dumps(transform_request, indent=4))
612+
LOGGER.info("Creating transform job with name: {}".format(job_name))
613+
LOGGER.debug("Transform request: {}".format(json.dumps(transform_request, indent=4)))
612614
self.sagemaker_client.create_transform_job(**transform_request)
613615

614616
def create_model(
@@ -680,8 +682,8 @@ def create_model(
680682
if enable_network_isolation:
681683
create_model_request["EnableNetworkIsolation"] = True
682684

683-
LOGGER.info("Creating model with name: %s", name)
684-
LOGGER.debug("CreateModel request: %s", json.dumps(create_model_request, indent=4))
685+
LOGGER.info("Creating model with name: {}".format(name))
686+
LOGGER.debug("CreateModel request: {}".format(json.dumps(create_model_request, indent=4)))
685687

686688
try:
687689
self.sagemaker_client.create_model(**create_model_request)
@@ -693,7 +695,7 @@ def create_model(
693695
error_code == "ValidationException"
694696
and "Cannot create already existing model" in message
695697
):
696-
LOGGER.warning("Using already existing model: %s", name)
698+
LOGGER.warning("Using already existing model: {}".format(name))
697699
else:
698700
raise
699701

@@ -764,14 +766,14 @@ def create_model_package_from_algorithm(self, name, description, algorithm_arn,
764766
},
765767
}
766768
try:
767-
LOGGER.info("Creating model package with name: %s", name)
769+
LOGGER.info("Creating model package with name: {}".format(name))
768770
self.sagemaker_client.create_model_package(**request)
769771
except ClientError as e:
770772
error_code = e.response["Error"]["Code"]
771773
message = e.response["Error"]["Message"]
772774

773775
if error_code == "ValidationException" and "ModelPackage already exists" in message:
774-
LOGGER.warning("Using already existing model package: %s", name)
776+
LOGGER.warning("Using already existing model package: {}".format(name))
775777
else:
776778
raise
777779

@@ -792,10 +794,10 @@ def wait_for_model_package(self, model_package_name, poll=5):
792794

793795
if status != "Completed":
794796
reason = desc.get("FailureReason", None)
795-
raise ValueError(
796-
"Error creating model package {}: {} Reason: {}".format(
797-
model_package_name, status, reason
798-
)
797+
raise exceptions.UnexpectedStatusException(
798+
message="Error creating model package {}: {} Reason: {}".format(model_package_name, status, reason),
799+
allowed_statuses=["Completed"],
800+
actual_status=status
799801
)
800802
return desc
801803

@@ -832,7 +834,7 @@ def create_endpoint_config(
832834
Returns:
833835
str: Name of the endpoint point configuration created.
834836
"""
835-
LOGGER.info("Creating endpoint-config with name %s", name)
837+
LOGGER.info("Creating endpoint-config with name {}".format(name))
836838

837839
tags = tags or []
838840

@@ -871,7 +873,7 @@ def create_endpoint(self, endpoint_name, config_name, tags=None, wait=True):
871873
Returns:
872874
str: Name of the Amazon SageMaker ``Endpoint`` created.
873875
"""
874-
LOGGER.info("Creating endpoint with name %s", endpoint_name)
876+
LOGGER.info("Creating endpoint with name {}".format(endpoint_name))
875877

876878
tags = tags or []
877879

@@ -914,7 +916,7 @@ def delete_endpoint(self, endpoint_name):
914916
Args:
915917
endpoint_name (str): Name of the Amazon SageMaker ``Endpoint`` to delete.
916918
"""
917-
LOGGER.info("Deleting endpoint with name: %s", endpoint_name)
919+
LOGGER.info("Deleting endpoint with name: {}".format(endpoint_name))
918920
self.sagemaker_client.delete_endpoint(EndpointName=endpoint_name)
919921

920922
def delete_endpoint_config(self, endpoint_config_name):
@@ -923,7 +925,7 @@ def delete_endpoint_config(self, endpoint_config_name):
923925
Args:
924926
endpoint_config_name (str): Name of the Amazon SageMaker endpoint configuration to delete.
925927
"""
926-
LOGGER.info("Deleting endpoint configuration with name: %s", endpoint_config_name)
928+
LOGGER.info("Deleting endpoint configuration with name: {}".format(endpoint_config_name))
927929
self.sagemaker_client.delete_endpoint_config(EndpointConfigName=endpoint_config_name)
928930

929931
def delete_model(self, model_name):
@@ -933,7 +935,7 @@ def delete_model(self, model_name):
933935
model_name (str): Name of the Amazon SageMaker model to delete.
934936
935937
"""
936-
LOGGER.info("Deleting model with name: %s", model_name)
938+
LOGGER.info("Deleting model with name: {}".format(model_name))
937939
self.sagemaker_client.delete_model(ModelName=model_name)
938940

939941
def wait_for_job(self, job, poll=5):
@@ -947,7 +949,7 @@ def wait_for_job(self, job, poll=5):
947949
(dict): Return value from the ``DescribeTrainingJob`` API.
948950
949951
Raises:
950-
ValueError: If the training job fails.
952+
exceptions.UnexpectedStatusException: If the training job fails.
951953
"""
952954
desc = _wait_until_training_done(
953955
lambda last_desc: _train_done(self.sagemaker_client, job, last_desc), None, poll
@@ -966,7 +968,7 @@ def wait_for_compilation_job(self, job, poll=5):
966968
(dict): Return value from the ``DescribeCompilationJob`` API.
967969
968970
Raises:
969-
ValueError: If the compilation job fails.
971+
exceptions.UnexpectedStatusException: If the compilation job fails.
970972
"""
971973
desc = _wait_until(lambda: _compilation_job_status(self.sagemaker_client, job), poll)
972974
self._check_job_status(job, desc, "CompilationJobStatus")
@@ -983,7 +985,7 @@ def wait_for_tuning_job(self, job, poll=5):
983985
(dict): Return value from the ``DescribeHyperParameterTuningJob`` API.
984986
985987
Raises:
986-
ValueError: If the hyperparameter tuning job fails.
988+
exceptions.UnexpectedStatusException: If the hyperparameter tuning job fails.
987989
"""
988990
desc = _wait_until(lambda: _tuning_job_status(self.sagemaker_client, job), poll)
989991
self._check_job_status(job, desc, "HyperParameterTuningJobStatus")
@@ -1000,23 +1002,23 @@ def wait_for_transform_job(self, job, poll=5):
10001002
(dict): Return value from the ``DescribeTransformJob`` API.
10011003
10021004
Raises:
1003-
ValueError: If the transform job fails.
1005+
exceptions.UnexpectedStatusException: If the transform job fails.
10041006
"""
10051007
desc = _wait_until(lambda: _transform_job_status(self.sagemaker_client, job), poll)
10061008
self._check_job_status(job, desc, "TransformJobStatus")
10071009
return desc
10081010

10091011
def _check_job_status(self, job, desc, status_key_name):
10101012
"""Check to see if the job completed successfully and, if not, construct and
1011-
raise a ValueError.
1013+
raise a exceptions.UnexpectedStatusException.
10121014
10131015
Args:
10141016
job (str): The name of the job to check.
10151017
desc (dict[str, str]): The result of ``describe_training_job()``.
10161018
status_key_name (str): Status key name to check for.
10171019
10181020
Raises:
1019-
ValueError: If the training job fails.
1021+
exceptions.UnexpectedStatusException: If the training job fails.
10201022
"""
10211023
status = desc[status_key_name]
10221024
# If the status is capital case, then convert it to Camel case
@@ -1025,7 +1027,11 @@ def _check_job_status(self, job, desc, status_key_name):
10251027
if status != "Completed" and status != "Stopped":
10261028
reason = desc.get("FailureReason", "(No reason provided)")
10271029
job_type = status_key_name.replace("JobStatus", " job")
1028-
raise ValueError("Error for {} {}: {} Reason: {}".format(job_type, job, status, reason))
1030+
raise exceptions.UnexpectedStatusException(
1031+
message="Error for {} {}: {} Reason: {}".format(job_type, job, status, reason),
1032+
allowed_statuses=["Completed", "Stopped"],
1033+
actual_status=status
1034+
)
10291035

10301036
def wait_for_endpoint(self, endpoint, poll=5):
10311037
"""Wait for an Amazon SageMaker endpoint deployment to complete.
@@ -1042,8 +1048,10 @@ def wait_for_endpoint(self, endpoint, poll=5):
10421048

10431049
if status != "InService":
10441050
reason = desc.get("FailureReason", None)
1045-
raise ValueError(
1046-
"Error hosting endpoint {}: {} Reason: {}".format(endpoint, status, reason)
1051+
raise exceptions.UnexpectedStatusException(
1052+
message="Error hosting endpoint {}: {} Reason: {}".format(endpoint, status, reason),
1053+
allowed_statuses=["InService"],
1054+
actual_status=status
10471055
)
10481056
return desc
10491057

@@ -1257,8 +1265,9 @@ def get_caller_identity_arn(self):
12571265
role = self.boto_session.client("iam").get_role(RoleName=role_name)["Role"]["Arn"]
12581266
except ClientError:
12591267
LOGGER.warning(
1260-
"Couldn't call 'get_role' to get Role ARN from role name %s to get Role path.",
1261-
role_name,
1268+
"Couldn't call 'get_role' to get Role ARN from role name {} to get Role path.".format(
1269+
role_name
1270+
)
12621271
)
12631272

12641273
return role
@@ -1276,7 +1285,7 @@ def logs_for_job( # noqa: C901 - suppress complexity warning for this method
12761285
poll (int): The interval in seconds between polling for new log entries and job completion (default: 5).
12771286
12781287
Raises:
1279-
ValueError: If waiting and the training job fails.
1288+
exceptions.UnexpectedStatusException: If waiting and the training job fails.
12801289
"""
12811290

12821291
description = self.sagemaker_client.describe_training_job(TrainingJobName=job_name)
Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
import pytest
16+
from mock import Mock, MagicMock
17+
import sagemaker
18+
19+
EXPANDED_ROLE = 'arn:aws:iam::111111111111:role/ExpandedRole'
20+
REGION = 'us-west-2'
21+
MODEL_PACKAGE_NAME = 'my_model_package'
22+
JOB_NAME = 'my_job_name'
23+
ENDPOINT_NAME = 'the_point_of_end'
24+
25+
26+
def get_sagemaker_session(returns_status):
27+
boto_mock = Mock(name='boto_session', region_name=REGION)
28+
client_mock = Mock()
29+
client_mock.describe_model_package = MagicMock(return_value={'ModelPackageStatus': returns_status})
30+
client_mock.describe_endpoint = MagicMock(return_value={'EndpointStatus': returns_status})
31+
ims = sagemaker.Session(boto_session=boto_mock, sagemaker_client=client_mock)
32+
ims.expand_role = Mock(return_value=EXPANDED_ROLE)
33+
return ims
34+
35+
36+
def test_does_not_raise_when_successfully_created_package():
37+
try:
38+
sagemaker_session = get_sagemaker_session(returns_status='Completed')
39+
sagemaker_session.wait_for_model_package(MODEL_PACKAGE_NAME)
40+
except sagemaker.exceptions.UnexpectedStatusException:
41+
pytest.fail("UnexpectedStatusException was thrown while it should not")
42+
43+
44+
def test_raise_when_failed_created_package():
45+
try:
46+
sagemaker_session = get_sagemaker_session(returns_status='EnRoute')
47+
sagemaker_session.wait_for_model_package(MODEL_PACKAGE_NAME)
48+
assert False, 'sagemaker.exceptions.UnexpectedStatusException should have been raised but was not'
49+
except Exception as e:
50+
assert type(e) == sagemaker.exceptions.UnexpectedStatusException
51+
assert e.actual_status == 'EnRoute'
52+
assert 'Completed' in e.allowed_statuses
53+
54+
55+
def test_does_not_raise_when_correct_job_status():
56+
try:
57+
job = Mock()
58+
sagemaker_session = get_sagemaker_session(returns_status='Stopped')
59+
sagemaker_session._check_job_status(job, {'TransformationJobStatus': 'Stopped'}, 'TransformationJobStatus')
60+
except sagemaker.exceptions.UnexpectedStatusException:
61+
pytest.fail("UnexpectedStatusException was thrown while it should not")
62+
63+
64+
def test_does_raise_when_incorrect_job_status():
65+
try:
66+
job = Mock()
67+
sagemaker_session = get_sagemaker_session(returns_status='Failed')
68+
sagemaker_session._check_job_status(job, {'TransformationJobStatus': 'Failed'}, 'TransformationJobStatus')
69+
assert False, 'sagemaker.exceptions.UnexpectedStatusException should have been raised but was not'
70+
except Exception as e:
71+
assert type(e) == sagemaker.exceptions.UnexpectedStatusException
72+
assert e.actual_status == 'Failed'
73+
assert 'Completed' in e.allowed_statuses
74+
assert 'Stopped' in e.allowed_statuses
75+
76+
77+
def test_does_not_raise_when_successfully_deployed_endpoint():
78+
try:
79+
sagemaker_session = get_sagemaker_session(returns_status='InService')
80+
sagemaker_session.wait_for_endpoint(ENDPOINT_NAME)
81+
except sagemaker.exceptions.UnexpectedStatusException:
82+
pytest.fail("UnexpectedStatusException was thrown while it should not")
83+
84+
85+
def test_raise_when_failed_to_deploy_endpoint():
86+
try:
87+
sagemaker_session = get_sagemaker_session(returns_status='Failed')
88+
assert sagemaker_session.wait_for_endpoint(ENDPOINT_NAME)
89+
assert False, 'sagemaker.exceptions.UnexpectedStatusException should have been raised but was not'
90+
except Exception as e:
91+
assert type(e) == sagemaker.exceptions.UnexpectedStatusException
92+
assert e.actual_status == 'Failed'
93+
assert 'InService' in e.allowed_statuses

0 commit comments

Comments
 (0)