35
35
secondary_training_status_changed ,
36
36
secondary_training_status_message ,
37
37
)
38
+ from sagemaker import exceptions
38
39
39
40
LOGGER = logging .getLogger ("sagemaker" )
40
41
@@ -826,10 +827,12 @@ def wait_for_model_package(self, model_package_name, poll=5):
826
827
827
828
if status != "Completed" :
828
829
reason = desc .get ("FailureReason" , None )
829
- raise ValueError (
830
- "Error creating model package {}: {} Reason: {}" .format (
831
- model_package_name , status , reason
832
- )
830
+ raise exceptions .UnexpectedStatusException (
831
+ message = "Error creating model package {package}: {status} Reason: {reason}" .format (
832
+ package = model_package_name , status = status , reason = reason
833
+ ),
834
+ allowed_statuses = ["Completed" ],
835
+ actual_status = status ,
833
836
)
834
837
return desc
835
838
@@ -990,7 +993,7 @@ def wait_for_job(self, job, poll=5):
990
993
(dict): Return value from the ``DescribeTrainingJob`` API.
991
994
992
995
Raises:
993
- ValueError : If the training job fails.
996
+ exceptions.UnexpectedStatusException : If the training job fails.
994
997
"""
995
998
desc = _wait_until_training_done (
996
999
lambda last_desc : _train_done (self .sagemaker_client , job , last_desc ), None , poll
@@ -1009,7 +1012,7 @@ def wait_for_compilation_job(self, job, poll=5):
1009
1012
(dict): Return value from the ``DescribeCompilationJob`` API.
1010
1013
1011
1014
Raises:
1012
- ValueError : If the compilation job fails.
1015
+ exceptions.UnexpectedStatusException : If the compilation job fails.
1013
1016
"""
1014
1017
desc = _wait_until (lambda : _compilation_job_status (self .sagemaker_client , job ), poll )
1015
1018
self ._check_job_status (job , desc , "CompilationJobStatus" )
@@ -1026,7 +1029,7 @@ def wait_for_tuning_job(self, job, poll=5):
1026
1029
(dict): Return value from the ``DescribeHyperParameterTuningJob`` API.
1027
1030
1028
1031
Raises:
1029
- ValueError : If the hyperparameter tuning job fails.
1032
+ exceptions.UnexpectedStatusException : If the hyperparameter tuning job fails.
1030
1033
"""
1031
1034
desc = _wait_until (lambda : _tuning_job_status (self .sagemaker_client , job ), poll )
1032
1035
self ._check_job_status (job , desc , "HyperParameterTuningJobStatus" )
@@ -1043,23 +1046,23 @@ def wait_for_transform_job(self, job, poll=5):
1043
1046
(dict): Return value from the ``DescribeTransformJob`` API.
1044
1047
1045
1048
Raises:
1046
- ValueError : If the transform job fails.
1049
+ exceptions.UnexpectedStatusException : If the transform job fails.
1047
1050
"""
1048
1051
desc = _wait_until (lambda : _transform_job_status (self .sagemaker_client , job ), poll )
1049
1052
self ._check_job_status (job , desc , "TransformJobStatus" )
1050
1053
return desc
1051
1054
1052
1055
def _check_job_status (self , job , desc , status_key_name ):
1053
1056
"""Check to see if the job completed successfully and, if not, construct and
1054
- raise a ValueError .
1057
+ raise a exceptions.UnexpectedStatusException .
1055
1058
1056
1059
Args:
1057
1060
job (str): The name of the job to check.
1058
1061
desc (dict[str, str]): The result of ``describe_training_job()``.
1059
1062
status_key_name (str): Status key name to check for.
1060
1063
1061
1064
Raises:
1062
- ValueError : If the training job fails.
1065
+ exceptions.UnexpectedStatusException : If the training job fails.
1063
1066
"""
1064
1067
status = desc [status_key_name ]
1065
1068
# If the status is capital case, then convert it to Camel case
@@ -1068,7 +1071,13 @@ def _check_job_status(self, job, desc, status_key_name):
1068
1071
if status not in ("Completed" , "Stopped" ):
1069
1072
reason = desc .get ("FailureReason" , "(No reason provided)" )
1070
1073
job_type = status_key_name .replace ("JobStatus" , " job" )
1071
- raise ValueError ("Error for {} {}: {} Reason: {}" .format (job_type , job , status , reason ))
1074
+ raise exceptions .UnexpectedStatusException (
1075
+ message = "Error for {job_type} {job_name}: {status}. Reason: {reason}" .format (
1076
+ job_type = job_type , job_name = job , status = status , reason = reason
1077
+ ),
1078
+ allowed_statuses = ["Completed" , "Stopped" ],
1079
+ actual_status = status ,
1080
+ )
1072
1081
1073
1082
def wait_for_endpoint (self , endpoint , poll = 5 ):
1074
1083
"""Wait for an Amazon SageMaker endpoint deployment to complete.
@@ -1085,8 +1094,12 @@ def wait_for_endpoint(self, endpoint, poll=5):
1085
1094
1086
1095
if status != "InService" :
1087
1096
reason = desc .get ("FailureReason" , None )
1088
- raise ValueError (
1089
- "Error hosting endpoint {}: {} Reason: {}" .format (endpoint , status , reason )
1097
+ raise exceptions .UnexpectedStatusException (
1098
+ message = "Error hosting endpoint {endpoint}: {status}. Reason: {reason}." .format (
1099
+ endpoint = endpoint , status = status , reason = reason
1100
+ ),
1101
+ allowed_statuses = ["InService" ],
1102
+ actual_status = status ,
1090
1103
)
1091
1104
return desc
1092
1105
@@ -1334,7 +1347,7 @@ def logs_for_job( # noqa: C901 - suppress complexity warning for this method
1334
1347
completion (default: 5).
1335
1348
1336
1349
Raises:
1337
- ValueError : If waiting and the training job fails.
1350
+ exceptions.UnexpectedStatusException : If waiting and the training job fails.
1338
1351
"""
1339
1352
1340
1353
description = self .sagemaker_client .describe_training_job (TrainingJobName = job_name )
0 commit comments