34
34
secondary_training_status_changed ,
35
35
secondary_training_status_message ,
36
36
)
37
+ from sagemaker import exceptions
37
38
38
39
LOGGER = logging .getLogger ("sagemaker" )
39
40
@@ -206,7 +207,7 @@ def default_bucket(self):
206
207
Bucket = default_bucket , CreateBucketConfiguration = {"LocationConstraint" : region }
207
208
)
208
209
209
- LOGGER .info ("Created S3 bucket: %s" , default_bucket )
210
+ LOGGER .info ("Created S3 bucket: {}" . format ( default_bucket ) )
210
211
except ClientError as e :
211
212
error_code = e .response ["Error" ]["Code" ]
212
213
message = e .response ["Error" ]["Message" ]
@@ -343,8 +344,8 @@ def train( # noqa: C901
343
344
if encrypt_inter_container_traffic :
344
345
train_request ["EnableInterContainerTrafficEncryption" ] = encrypt_inter_container_traffic
345
346
346
- LOGGER .info ("Creating training-job with name: %s" , job_name )
347
- LOGGER .debug ("train request: %s" , json .dumps (train_request , indent = 4 ))
347
+ LOGGER .info ("Creating training-job with name: {}" . format ( job_name ) )
348
+ LOGGER .debug ("train request: {}" . format ( json .dumps (train_request , indent = 4 ) ))
348
349
self .sagemaker_client .create_training_job (** train_request )
349
350
350
351
def compile_model (
@@ -379,7 +380,7 @@ def compile_model(
379
380
if tags is not None :
380
381
compilation_job_request ["Tags" ] = tags
381
382
382
- LOGGER .info ("Creating compilation-job with name: %s" , job_name )
383
+ LOGGER .info ("Creating compilation-job with name: {}" . format ( job_name ) )
383
384
self .sagemaker_client .create_compilation_job (** compilation_job_request )
384
385
385
386
def tune (
@@ -521,8 +522,8 @@ def tune(
521
522
if encrypt_inter_container_traffic :
522
523
tune_request ["TrainingJobDefinition" ]["EnableInterContainerTrafficEncryption" ] = True
523
524
524
- LOGGER .info ("Creating hyperparameter tuning job with name: %s" , job_name )
525
- LOGGER .debug ("tune request: %s" , json .dumps (tune_request , indent = 4 ))
525
+ LOGGER .info ("Creating hyperparameter tuning job with name: {}" . format ( job_name ) )
526
+ LOGGER .debug ("tune request: {}" . format ( json .dumps (tune_request , indent = 4 ) ))
526
527
self .sagemaker_client .create_hyper_parameter_tuning_job (** tune_request )
527
528
528
529
def stop_tuning_job (self , name ):
@@ -535,17 +536,18 @@ def stop_tuning_job(self, name):
535
536
ClientError: If an error occurs while trying to stop the hyperparameter tuning job.
536
537
"""
537
538
try :
538
- LOGGER .info ("Stopping tuning job: %s" , name )
539
+ LOGGER .info ("Stopping tuning job: {}" . format ( name ) )
539
540
self .sagemaker_client .stop_hyper_parameter_tuning_job (HyperParameterTuningJobName = name )
540
541
except ClientError as e :
541
542
error_code = e .response ["Error" ]["Code" ]
542
543
# allow to pass if the job already stopped
543
544
if error_code == "ValidationException" :
544
- LOGGER .info ("Tuning job: %s is already stopped or not running." , name )
545
+ LOGGER .info ("Tuning job: {} is already stopped or not running." . format ( name ) )
545
546
else :
546
547
LOGGER .error (
547
- "Error occurred while attempting to stop tuning job: %s. Please try again." ,
548
- name ,
548
+ "Error occurred while attempting to stop tuning job: {}. Please try again." .format (
549
+ name
550
+ )
549
551
)
550
552
raise
551
553
@@ -607,8 +609,8 @@ def transform(
607
609
if data_processing is not None :
608
610
transform_request ["DataProcessing" ] = data_processing
609
611
610
- LOGGER .info ("Creating transform job with name: %s" , job_name )
611
- LOGGER .debug ("Transform request: %s" , json .dumps (transform_request , indent = 4 ))
612
+ LOGGER .info ("Creating transform job with name: {}" . format ( job_name ) )
613
+ LOGGER .debug ("Transform request: {}" . format ( json .dumps (transform_request , indent = 4 ) ))
612
614
self .sagemaker_client .create_transform_job (** transform_request )
613
615
614
616
def create_model (
@@ -680,8 +682,8 @@ def create_model(
680
682
if enable_network_isolation :
681
683
create_model_request ["EnableNetworkIsolation" ] = True
682
684
683
- LOGGER .info ("Creating model with name: %s" , name )
684
- LOGGER .debug ("CreateModel request: %s" , json .dumps (create_model_request , indent = 4 ))
685
+ LOGGER .info ("Creating model with name: {}" . format ( name ) )
686
+ LOGGER .debug ("CreateModel request: {}" . format ( json .dumps (create_model_request , indent = 4 ) ))
685
687
686
688
try :
687
689
self .sagemaker_client .create_model (** create_model_request )
@@ -693,7 +695,7 @@ def create_model(
693
695
error_code == "ValidationException"
694
696
and "Cannot create already existing model" in message
695
697
):
696
- LOGGER .warning ("Using already existing model: %s" , name )
698
+ LOGGER .warning ("Using already existing model: {}" . format ( name ) )
697
699
else :
698
700
raise
699
701
@@ -764,14 +766,14 @@ def create_model_package_from_algorithm(self, name, description, algorithm_arn,
764
766
},
765
767
}
766
768
try :
767
- LOGGER .info ("Creating model package with name: %s" , name )
769
+ LOGGER .info ("Creating model package with name: {}" . format ( name ) )
768
770
self .sagemaker_client .create_model_package (** request )
769
771
except ClientError as e :
770
772
error_code = e .response ["Error" ]["Code" ]
771
773
message = e .response ["Error" ]["Message" ]
772
774
773
775
if error_code == "ValidationException" and "ModelPackage already exists" in message :
774
- LOGGER .warning ("Using already existing model package: %s" , name )
776
+ LOGGER .warning ("Using already existing model package: {}" . format ( name ) )
775
777
else :
776
778
raise
777
779
@@ -792,10 +794,10 @@ def wait_for_model_package(self, model_package_name, poll=5):
792
794
793
795
if status != "Completed" :
794
796
reason = desc .get ("FailureReason" , None )
795
- raise ValueError (
796
- "Error creating model package {}: {} Reason: {}" .format (
797
- model_package_name , status , reason
798
- )
797
+ raise exceptions . UnexpectedStatusException (
798
+ message = "Error creating model package {}: {} Reason: {}" .format (model_package_name , status , reason ),
799
+ allowed_statuses = [ "Completed" ],
800
+ actual_status = status
799
801
)
800
802
return desc
801
803
@@ -832,7 +834,7 @@ def create_endpoint_config(
832
834
Returns:
833
835
str: Name of the endpoint point configuration created.
834
836
"""
835
- LOGGER .info ("Creating endpoint-config with name %s" , name )
837
+ LOGGER .info ("Creating endpoint-config with name {}" . format ( name ) )
836
838
837
839
tags = tags or []
838
840
@@ -871,7 +873,7 @@ def create_endpoint(self, endpoint_name, config_name, tags=None, wait=True):
871
873
Returns:
872
874
str: Name of the Amazon SageMaker ``Endpoint`` created.
873
875
"""
874
- LOGGER .info ("Creating endpoint with name %s" , endpoint_name )
876
+ LOGGER .info ("Creating endpoint with name {}" . format ( endpoint_name ) )
875
877
876
878
tags = tags or []
877
879
@@ -914,7 +916,7 @@ def delete_endpoint(self, endpoint_name):
914
916
Args:
915
917
endpoint_name (str): Name of the Amazon SageMaker ``Endpoint`` to delete.
916
918
"""
917
- LOGGER .info ("Deleting endpoint with name: %s" , endpoint_name )
919
+ LOGGER .info ("Deleting endpoint with name: {}" . format ( endpoint_name ) )
918
920
self .sagemaker_client .delete_endpoint (EndpointName = endpoint_name )
919
921
920
922
def delete_endpoint_config (self , endpoint_config_name ):
@@ -923,7 +925,7 @@ def delete_endpoint_config(self, endpoint_config_name):
923
925
Args:
924
926
endpoint_config_name (str): Name of the Amazon SageMaker endpoint configuration to delete.
925
927
"""
926
- LOGGER .info ("Deleting endpoint configuration with name: %s" , endpoint_config_name )
928
+ LOGGER .info ("Deleting endpoint configuration with name: {}" . format ( endpoint_config_name ) )
927
929
self .sagemaker_client .delete_endpoint_config (EndpointConfigName = endpoint_config_name )
928
930
929
931
def delete_model (self , model_name ):
@@ -933,7 +935,7 @@ def delete_model(self, model_name):
933
935
model_name (str): Name of the Amazon SageMaker model to delete.
934
936
935
937
"""
936
- LOGGER .info ("Deleting model with name: %s" , model_name )
938
+ LOGGER .info ("Deleting model with name: {}" . format ( model_name ) )
937
939
self .sagemaker_client .delete_model (ModelName = model_name )
938
940
939
941
def wait_for_job (self , job , poll = 5 ):
@@ -947,7 +949,7 @@ def wait_for_job(self, job, poll=5):
947
949
(dict): Return value from the ``DescribeTrainingJob`` API.
948
950
949
951
Raises:
950
- ValueError : If the training job fails.
952
+ exceptions.UnexpectedStatusException : If the training job fails.
951
953
"""
952
954
desc = _wait_until_training_done (
953
955
lambda last_desc : _train_done (self .sagemaker_client , job , last_desc ), None , poll
@@ -966,7 +968,7 @@ def wait_for_compilation_job(self, job, poll=5):
966
968
(dict): Return value from the ``DescribeCompilationJob`` API.
967
969
968
970
Raises:
969
- ValueError : If the compilation job fails.
971
+ exceptions.UnexpectedStatusException : If the compilation job fails.
970
972
"""
971
973
desc = _wait_until (lambda : _compilation_job_status (self .sagemaker_client , job ), poll )
972
974
self ._check_job_status (job , desc , "CompilationJobStatus" )
@@ -983,7 +985,7 @@ def wait_for_tuning_job(self, job, poll=5):
983
985
(dict): Return value from the ``DescribeHyperParameterTuningJob`` API.
984
986
985
987
Raises:
986
- ValueError : If the hyperparameter tuning job fails.
988
+ exceptions.UnexpectedStatusException : If the hyperparameter tuning job fails.
987
989
"""
988
990
desc = _wait_until (lambda : _tuning_job_status (self .sagemaker_client , job ), poll )
989
991
self ._check_job_status (job , desc , "HyperParameterTuningJobStatus" )
@@ -1000,23 +1002,23 @@ def wait_for_transform_job(self, job, poll=5):
1000
1002
(dict): Return value from the ``DescribeTransformJob`` API.
1001
1003
1002
1004
Raises:
1003
- ValueError : If the transform job fails.
1005
+ exceptions.UnexpectedStatusException : If the transform job fails.
1004
1006
"""
1005
1007
desc = _wait_until (lambda : _transform_job_status (self .sagemaker_client , job ), poll )
1006
1008
self ._check_job_status (job , desc , "TransformJobStatus" )
1007
1009
return desc
1008
1010
1009
1011
def _check_job_status (self , job , desc , status_key_name ):
1010
1012
"""Check to see if the job completed successfully and, if not, construct and
1011
- raise a ValueError .
1013
+ raise a exceptions.UnexpectedStatusException .
1012
1014
1013
1015
Args:
1014
1016
job (str): The name of the job to check.
1015
1017
desc (dict[str, str]): The result of ``describe_training_job()``.
1016
1018
status_key_name (str): Status key name to check for.
1017
1019
1018
1020
Raises:
1019
- ValueError : If the training job fails.
1021
+ exceptions.UnexpectedStatusException : If the training job fails.
1020
1022
"""
1021
1023
status = desc [status_key_name ]
1022
1024
# If the status is capital case, then convert it to Camel case
@@ -1025,7 +1027,11 @@ def _check_job_status(self, job, desc, status_key_name):
1025
1027
if status != "Completed" and status != "Stopped" :
1026
1028
reason = desc .get ("FailureReason" , "(No reason provided)" )
1027
1029
job_type = status_key_name .replace ("JobStatus" , " job" )
1028
- raise ValueError ("Error for {} {}: {} Reason: {}" .format (job_type , job , status , reason ))
1030
+ raise exceptions .UnexpectedStatusException (
1031
+ message = "Error for {} {}: {} Reason: {}" .format (job_type , job , status , reason ),
1032
+ allowed_statuses = ["Completed" , "Stopped" ],
1033
+ actual_status = status
1034
+ )
1029
1035
1030
1036
def wait_for_endpoint (self , endpoint , poll = 5 ):
1031
1037
"""Wait for an Amazon SageMaker endpoint deployment to complete.
@@ -1042,8 +1048,10 @@ def wait_for_endpoint(self, endpoint, poll=5):
1042
1048
1043
1049
if status != "InService" :
1044
1050
reason = desc .get ("FailureReason" , None )
1045
- raise ValueError (
1046
- "Error hosting endpoint {}: {} Reason: {}" .format (endpoint , status , reason )
1051
+ raise exceptions .UnexpectedStatusException (
1052
+ message = "Error hosting endpoint {}: {} Reason: {}" .format (endpoint , status , reason ),
1053
+ allowed_statuses = ["InService" ],
1054
+ actual_status = status
1047
1055
)
1048
1056
return desc
1049
1057
@@ -1257,8 +1265,9 @@ def get_caller_identity_arn(self):
1257
1265
role = self .boto_session .client ("iam" ).get_role (RoleName = role_name )["Role" ]["Arn" ]
1258
1266
except ClientError :
1259
1267
LOGGER .warning (
1260
- "Couldn't call 'get_role' to get Role ARN from role name %s to get Role path." ,
1261
- role_name ,
1268
+ "Couldn't call 'get_role' to get Role ARN from role name {} to get Role path." .format (
1269
+ role_name
1270
+ )
1262
1271
)
1263
1272
1264
1273
return role
@@ -1276,7 +1285,7 @@ def logs_for_job( # noqa: C901 - suppress complexity warning for this method
1276
1285
poll (int): The interval in seconds between polling for new log entries and job completion (default: 5).
1277
1286
1278
1287
Raises:
1279
- ValueError : If waiting and the training job fails.
1288
+ exceptions.UnexpectedStatusException : If waiting and the training job fails.
1280
1289
"""
1281
1290
1282
1291
description = self .sagemaker_client .describe_training_job (TrainingJobName = job_name )
0 commit comments