Skip to content

Commit 77c9051

Browse files
author
Quentin R. Voglund
committed
black-check update
1 parent 5e31ea0 commit 77c9051

File tree

1 file changed

+66
-24
lines changed

1 file changed

+66
-24
lines changed

src/sagemaker/processing.py

Lines changed: 66 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,9 @@
3030
from sagemaker.local import LocalSession
3131
from sagemaker.utils import base_name_from_image, name_from_base
3232
from sagemaker.session import Session
33-
from sagemaker.network import NetworkConfig # noqa: F401 # pylint: disable=unused-import
33+
from sagemaker.network import (
34+
NetworkConfig,
35+
) # noqa: F401 # pylint: disable=unused-import
3436
from sagemaker.workflow.properties import Properties
3537
from sagemaker.workflow.parameters import Parameter
3638
from sagemaker.workflow.entities import Expression
@@ -185,7 +187,9 @@ def run(
185187
if wait:
186188
self.latest_job.wait(logs=logs)
187189

188-
def _extend_processing_args(self, inputs, outputs, **kwargs): # pylint: disable=W0613
190+
def _extend_processing_args(
191+
self, inputs, outputs, **kwargs
192+
): # pylint: disable=W0613
189193
"""Extend inputs and outputs based on extra parameters"""
190194
return inputs, outputs
191195

@@ -287,15 +291,22 @@ def _normalize_inputs(self, inputs=None, kms_key=None):
287291
# Iterate through the provided list of inputs.
288292
for count, file_input in enumerate(inputs, 1):
289293
if not isinstance(file_input, ProcessingInput):
290-
raise TypeError("Your inputs must be provided as ProcessingInput objects.")
294+
raise TypeError(
295+
"Your inputs must be provided as ProcessingInput objects."
296+
)
291297
# Generate a name for the ProcessingInput if it doesn't have one.
292298
if file_input.input_name is None:
293299
file_input.input_name = "input-{}".format(count)
294300

295-
if isinstance(file_input.source, Properties) or file_input.dataset_definition:
301+
if (
302+
isinstance(file_input.source, Properties)
303+
or file_input.dataset_definition
304+
):
296305
normalized_inputs.append(file_input)
297306
continue
298-
if isinstance(file_input.s3_input.s3_uri, (Parameter, Expression, Properties)):
307+
if isinstance(
308+
file_input.s3_input.s3_uri, (Parameter, Expression, Properties)
309+
):
299310
normalized_inputs.append(file_input)
300311
continue
301312
# If the source is a local path, upload it to S3
@@ -341,7 +352,9 @@ def _normalize_outputs(self, outputs=None):
341352
# Iterate through the provided list of outputs.
342353
for count, output in enumerate(outputs, 1):
343354
if not isinstance(output, ProcessingOutput):
344-
raise TypeError("Your outputs must be provided as ProcessingOutput objects.")
355+
raise TypeError(
356+
"Your outputs must be provided as ProcessingOutput objects."
357+
)
345358
# Generate a name for the ProcessingOutput if it doesn't have one.
346359
if output.output_name is None:
347360
output.output_name = "output-{}".format(count)
@@ -553,7 +566,9 @@ def _include_code_in_inputs(self, inputs, code, kms_key=None):
553566
user_code_s3_uri = self._handle_user_code_url(code, kms_key)
554567
user_script_name = self._get_user_code_name(code)
555568

556-
inputs_with_code = self._convert_code_and_add_to_inputs(inputs, user_code_s3_uri)
569+
inputs_with_code = self._convert_code_and_add_to_inputs(
570+
inputs, user_code_s3_uri
571+
)
557572

558573
self._set_entrypoint(self.command, user_script_name)
559574
return inputs_with_code
@@ -641,7 +656,7 @@ def _upload_code(self, code, kms_key=None):
641656
local_path=code,
642657
desired_s3_uri=desired_s3_uri,
643658
kms_key=kms_key,
644-
sagemaker_session=self.sagemaker_session
659+
sagemaker_session=self.sagemaker_session,
645660
)
646661

647662
def _convert_code_and_add_to_inputs(self, inputs, s3_uri):
@@ -677,7 +692,9 @@ def _set_entrypoint(self, command, user_script_name):
677692
"""
678693
user_script_location = str(
679694
pathlib.PurePosixPath(
680-
self._CODE_CONTAINER_BASE_PATH, self._CODE_CONTAINER_INPUT_NAME, user_script_name
695+
self._CODE_CONTAINER_BASE_PATH,
696+
self._CODE_CONTAINER_INPUT_NAME,
697+
user_script_name,
681698
)
682699
)
683700
self.entrypoint = command + [user_script_location]
@@ -686,7 +703,9 @@ def _set_entrypoint(self, command, user_script_name):
686703
class ProcessingJob(_Job):
687704
"""Provides functionality to start, describe, and stop processing jobs."""
688705

689-
def __init__(self, sagemaker_session, job_name, inputs, outputs, output_kms_key=None):
706+
def __init__(
707+
self, sagemaker_session, job_name, inputs, outputs, output_kms_key=None
708+
):
690709
"""Initializes a Processing job.
691710
692711
Args:
@@ -704,7 +723,9 @@ def __init__(self, sagemaker_session, job_name, inputs, outputs, output_kms_key=
704723
self.inputs = inputs
705724
self.outputs = outputs
706725
self.output_kms_key = output_kms_key
707-
super(ProcessingJob, self).__init__(sagemaker_session=sagemaker_session, job_name=job_name)
726+
super(ProcessingJob, self).__init__(
727+
sagemaker_session=sagemaker_session, job_name=job_name
728+
)
708729

709730
@classmethod
710731
def start_new(cls, processor, inputs, outputs, experiment_config):
@@ -725,7 +746,9 @@ def start_new(cls, processor, inputs, outputs, experiment_config):
725746
:class:`~sagemaker.processing.ProcessingJob`: The instance of ``ProcessingJob`` created
726747
using the ``Processor``.
727748
"""
728-
process_args = cls._get_process_args(processor, inputs, outputs, experiment_config)
749+
process_args = cls._get_process_args(
750+
processor, inputs, outputs, experiment_config
751+
)
729752

730753
# Print the job name and the user's inputs and outputs as lists of dictionaries.
731754
print()
@@ -799,18 +822,26 @@ def _get_process_args(cls, processor, inputs, outputs, experiment_config):
799822

800823
process_request_args["app_specification"] = {"ImageUri": processor.image_uri}
801824
if processor.arguments is not None:
802-
process_request_args["app_specification"]["ContainerArguments"] = processor.arguments
825+
process_request_args["app_specification"][
826+
"ContainerArguments"
827+
] = processor.arguments
803828
if processor.entrypoint is not None:
804-
process_request_args["app_specification"]["ContainerEntrypoint"] = processor.entrypoint
829+
process_request_args["app_specification"][
830+
"ContainerEntrypoint"
831+
] = processor.entrypoint
805832

806833
process_request_args["environment"] = processor.env
807834

808835
if processor.network_config is not None:
809-
process_request_args["network_config"] = processor.network_config._to_request_dict()
836+
process_request_args[
837+
"network_config"
838+
] = processor.network_config._to_request_dict()
810839
else:
811840
process_request_args["network_config"] = None
812841

813-
process_request_args["role_arn"] = processor.sagemaker_session.expand_role(processor.role)
842+
process_request_args["role_arn"] = processor.sagemaker_session.expand_role(
843+
processor.role
844+
)
814845

815846
process_request_args["tags"] = processor.tags
816847

@@ -831,7 +862,9 @@ def from_processing_name(cls, sagemaker_session, processing_job_name):
831862
:class:`~sagemaker.processing.ProcessingJob`: The instance of ``ProcessingJob`` created
832863
from the job name.
833864
"""
834-
job_desc = sagemaker_session.describe_processing_job(job_name=processing_job_name)
865+
job_desc = sagemaker_session.describe_processing_job(
866+
job_name=processing_job_name
867+
)
835868

836869
inputs = None
837870
if job_desc.get("ProcessingInputs"):
@@ -848,9 +881,9 @@ def from_processing_name(cls, sagemaker_session, processing_job_name):
848881
]
849882

850883
outputs = None
851-
if job_desc.get("ProcessingOutputConfig") and job_desc["ProcessingOutputConfig"].get(
852-
"Outputs"
853-
):
884+
if job_desc.get("ProcessingOutputConfig") and job_desc[
885+
"ProcessingOutputConfig"
886+
].get("Outputs"):
854887
outputs = []
855888
for processing_output_dict in job_desc["ProcessingOutputConfig"]["Outputs"]:
856889
processing_output = ProcessingOutput(
@@ -862,8 +895,12 @@ def from_processing_name(cls, sagemaker_session, processing_job_name):
862895
)
863896

864897
if "S3Output" in processing_output_dict:
865-
processing_output.source = processing_output_dict["S3Output"]["LocalPath"]
866-
processing_output.destination = processing_output_dict["S3Output"]["S3Uri"]
898+
processing_output.source = processing_output_dict["S3Output"][
899+
"LocalPath"
900+
]
901+
processing_output.destination = processing_output_dict["S3Output"][
902+
"S3Uri"
903+
]
867904

868905
outputs.append(processing_output)
869906
output_kms_key = None
@@ -1077,15 +1114,20 @@ def _to_request_dict(self):
10771114
"""Generates a request dictionary using the parameters provided to the class."""
10781115

10791116
# Create the request dictionary.
1080-
s3_input_request = {"InputName": self.input_name, "AppManaged": self.app_managed}
1117+
s3_input_request = {
1118+
"InputName": self.input_name,
1119+
"AppManaged": self.app_managed,
1120+
}
10811121

10821122
if self.s3_input:
10831123
# Check the compression type, then add it to the dictionary.
10841124
if (
10851125
self.s3_input.s3_compression_type == "Gzip"
10861126
and self.s3_input.s3_input_mode != "Pipe"
10871127
):
1088-
raise ValueError("Data can only be gzipped when the input mode is Pipe.")
1128+
raise ValueError(
1129+
"Data can only be gzipped when the input mode is Pipe."
1130+
)
10891131

10901132
s3_input_request["S3Input"] = S3Input.to_boto(self.s3_input)
10911133

0 commit comments

Comments
 (0)