30
30
from sagemaker .local import LocalSession
31
31
from sagemaker .utils import base_name_from_image , name_from_base
32
32
from sagemaker .session import Session
33
- from sagemaker .network import NetworkConfig # noqa: F401 # pylint: disable=unused-import
33
+ from sagemaker .network import (
34
+ NetworkConfig ,
35
+ ) # noqa: F401 # pylint: disable=unused-import
34
36
from sagemaker .workflow .properties import Properties
35
37
from sagemaker .workflow .parameters import Parameter
36
38
from sagemaker .workflow .entities import Expression
@@ -185,7 +187,9 @@ def run(
185
187
if wait :
186
188
self .latest_job .wait (logs = logs )
187
189
188
- def _extend_processing_args (self , inputs , outputs , ** kwargs ): # pylint: disable=W0613
190
+ def _extend_processing_args (
191
+ self , inputs , outputs , ** kwargs
192
+ ): # pylint: disable=W0613
189
193
"""Extend inputs and outputs based on extra parameters"""
190
194
return inputs , outputs
191
195
@@ -287,15 +291,22 @@ def _normalize_inputs(self, inputs=None, kms_key=None):
287
291
# Iterate through the provided list of inputs.
288
292
for count , file_input in enumerate (inputs , 1 ):
289
293
if not isinstance (file_input , ProcessingInput ):
290
- raise TypeError ("Your inputs must be provided as ProcessingInput objects." )
294
+ raise TypeError (
295
+ "Your inputs must be provided as ProcessingInput objects."
296
+ )
291
297
# Generate a name for the ProcessingInput if it doesn't have one.
292
298
if file_input .input_name is None :
293
299
file_input .input_name = "input-{}" .format (count )
294
300
295
- if isinstance (file_input .source , Properties ) or file_input .dataset_definition :
301
+ if (
302
+ isinstance (file_input .source , Properties )
303
+ or file_input .dataset_definition
304
+ ):
296
305
normalized_inputs .append (file_input )
297
306
continue
298
- if isinstance (file_input .s3_input .s3_uri , (Parameter , Expression , Properties )):
307
+ if isinstance (
308
+ file_input .s3_input .s3_uri , (Parameter , Expression , Properties )
309
+ ):
299
310
normalized_inputs .append (file_input )
300
311
continue
301
312
# If the source is a local path, upload it to S3
@@ -341,7 +352,9 @@ def _normalize_outputs(self, outputs=None):
341
352
# Iterate through the provided list of outputs.
342
353
for count , output in enumerate (outputs , 1 ):
343
354
if not isinstance (output , ProcessingOutput ):
344
- raise TypeError ("Your outputs must be provided as ProcessingOutput objects." )
355
+ raise TypeError (
356
+ "Your outputs must be provided as ProcessingOutput objects."
357
+ )
345
358
# Generate a name for the ProcessingOutput if it doesn't have one.
346
359
if output .output_name is None :
347
360
output .output_name = "output-{}" .format (count )
@@ -553,7 +566,9 @@ def _include_code_in_inputs(self, inputs, code, kms_key=None):
553
566
user_code_s3_uri = self ._handle_user_code_url (code , kms_key )
554
567
user_script_name = self ._get_user_code_name (code )
555
568
556
- inputs_with_code = self ._convert_code_and_add_to_inputs (inputs , user_code_s3_uri )
569
+ inputs_with_code = self ._convert_code_and_add_to_inputs (
570
+ inputs , user_code_s3_uri
571
+ )
557
572
558
573
self ._set_entrypoint (self .command , user_script_name )
559
574
return inputs_with_code
@@ -641,7 +656,7 @@ def _upload_code(self, code, kms_key=None):
641
656
local_path = code ,
642
657
desired_s3_uri = desired_s3_uri ,
643
658
kms_key = kms_key ,
644
- sagemaker_session = self .sagemaker_session
659
+ sagemaker_session = self .sagemaker_session ,
645
660
)
646
661
647
662
def _convert_code_and_add_to_inputs (self , inputs , s3_uri ):
@@ -677,7 +692,9 @@ def _set_entrypoint(self, command, user_script_name):
677
692
"""
678
693
user_script_location = str (
679
694
pathlib .PurePosixPath (
680
- self ._CODE_CONTAINER_BASE_PATH , self ._CODE_CONTAINER_INPUT_NAME , user_script_name
695
+ self ._CODE_CONTAINER_BASE_PATH ,
696
+ self ._CODE_CONTAINER_INPUT_NAME ,
697
+ user_script_name ,
681
698
)
682
699
)
683
700
self .entrypoint = command + [user_script_location ]
@@ -686,7 +703,9 @@ def _set_entrypoint(self, command, user_script_name):
686
703
class ProcessingJob (_Job ):
687
704
"""Provides functionality to start, describe, and stop processing jobs."""
688
705
689
- def __init__ (self , sagemaker_session , job_name , inputs , outputs , output_kms_key = None ):
706
+ def __init__ (
707
+ self , sagemaker_session , job_name , inputs , outputs , output_kms_key = None
708
+ ):
690
709
"""Initializes a Processing job.
691
710
692
711
Args:
@@ -704,7 +723,9 @@ def __init__(self, sagemaker_session, job_name, inputs, outputs, output_kms_key=
704
723
self .inputs = inputs
705
724
self .outputs = outputs
706
725
self .output_kms_key = output_kms_key
707
- super (ProcessingJob , self ).__init__ (sagemaker_session = sagemaker_session , job_name = job_name )
726
+ super (ProcessingJob , self ).__init__ (
727
+ sagemaker_session = sagemaker_session , job_name = job_name
728
+ )
708
729
709
730
@classmethod
710
731
def start_new (cls , processor , inputs , outputs , experiment_config ):
@@ -725,7 +746,9 @@ def start_new(cls, processor, inputs, outputs, experiment_config):
725
746
:class:`~sagemaker.processing.ProcessingJob`: The instance of ``ProcessingJob`` created
726
747
using the ``Processor``.
727
748
"""
728
- process_args = cls ._get_process_args (processor , inputs , outputs , experiment_config )
749
+ process_args = cls ._get_process_args (
750
+ processor , inputs , outputs , experiment_config
751
+ )
729
752
730
753
# Print the job name and the user's inputs and outputs as lists of dictionaries.
731
754
print ()
@@ -799,18 +822,26 @@ def _get_process_args(cls, processor, inputs, outputs, experiment_config):
799
822
800
823
process_request_args ["app_specification" ] = {"ImageUri" : processor .image_uri }
801
824
if processor .arguments is not None :
802
- process_request_args ["app_specification" ]["ContainerArguments" ] = processor .arguments
825
+ process_request_args ["app_specification" ][
826
+ "ContainerArguments"
827
+ ] = processor .arguments
803
828
if processor .entrypoint is not None :
804
- process_request_args ["app_specification" ]["ContainerEntrypoint" ] = processor .entrypoint
829
+ process_request_args ["app_specification" ][
830
+ "ContainerEntrypoint"
831
+ ] = processor .entrypoint
805
832
806
833
process_request_args ["environment" ] = processor .env
807
834
808
835
if processor .network_config is not None :
809
- process_request_args ["network_config" ] = processor .network_config ._to_request_dict ()
836
+ process_request_args [
837
+ "network_config"
838
+ ] = processor .network_config ._to_request_dict ()
810
839
else :
811
840
process_request_args ["network_config" ] = None
812
841
813
- process_request_args ["role_arn" ] = processor .sagemaker_session .expand_role (processor .role )
842
+ process_request_args ["role_arn" ] = processor .sagemaker_session .expand_role (
843
+ processor .role
844
+ )
814
845
815
846
process_request_args ["tags" ] = processor .tags
816
847
@@ -831,7 +862,9 @@ def from_processing_name(cls, sagemaker_session, processing_job_name):
831
862
:class:`~sagemaker.processing.ProcessingJob`: The instance of ``ProcessingJob`` created
832
863
from the job name.
833
864
"""
834
- job_desc = sagemaker_session .describe_processing_job (job_name = processing_job_name )
865
+ job_desc = sagemaker_session .describe_processing_job (
866
+ job_name = processing_job_name
867
+ )
835
868
836
869
inputs = None
837
870
if job_desc .get ("ProcessingInputs" ):
@@ -848,9 +881,9 @@ def from_processing_name(cls, sagemaker_session, processing_job_name):
848
881
]
849
882
850
883
outputs = None
851
- if job_desc .get ("ProcessingOutputConfig" ) and job_desc ["ProcessingOutputConfig" ]. get (
852
- "Outputs "
853
- ):
884
+ if job_desc .get ("ProcessingOutputConfig" ) and job_desc [
885
+ "ProcessingOutputConfig "
886
+ ]. get ( "Outputs" ):
854
887
outputs = []
855
888
for processing_output_dict in job_desc ["ProcessingOutputConfig" ]["Outputs" ]:
856
889
processing_output = ProcessingOutput (
@@ -862,8 +895,12 @@ def from_processing_name(cls, sagemaker_session, processing_job_name):
862
895
)
863
896
864
897
if "S3Output" in processing_output_dict :
865
- processing_output .source = processing_output_dict ["S3Output" ]["LocalPath" ]
866
- processing_output .destination = processing_output_dict ["S3Output" ]["S3Uri" ]
898
+ processing_output .source = processing_output_dict ["S3Output" ][
899
+ "LocalPath"
900
+ ]
901
+ processing_output .destination = processing_output_dict ["S3Output" ][
902
+ "S3Uri"
903
+ ]
867
904
868
905
outputs .append (processing_output )
869
906
output_kms_key = None
@@ -1077,15 +1114,20 @@ def _to_request_dict(self):
1077
1114
"""Generates a request dictionary using the parameters provided to the class."""
1078
1115
1079
1116
# Create the request dictionary.
1080
- s3_input_request = {"InputName" : self .input_name , "AppManaged" : self .app_managed }
1117
+ s3_input_request = {
1118
+ "InputName" : self .input_name ,
1119
+ "AppManaged" : self .app_managed ,
1120
+ }
1081
1121
1082
1122
if self .s3_input :
1083
1123
# Check the compression type, then add it to the dictionary.
1084
1124
if (
1085
1125
self .s3_input .s3_compression_type == "Gzip"
1086
1126
and self .s3_input .s3_input_mode != "Pipe"
1087
1127
):
1088
- raise ValueError ("Data can only be gzipped when the input mode is Pipe." )
1128
+ raise ValueError (
1129
+ "Data can only be gzipped when the input mode is Pipe."
1130
+ )
1089
1131
1090
1132
s3_input_request ["S3Input" ] = S3Input .to_boto (self .s3_input )
1091
1133
0 commit comments