Commit 81dfee2

Author: Payton Staub (committed)
Interim commit #3 - support all processors for ProcessingStep
Parent: 229ead5

File tree: 7 files changed (+799, -35 lines)


src/sagemaker/processing.py

Lines changed: 20 additions & 13 deletions
@@ -463,24 +463,29 @@ def __init__(

     def get_run_args(
         self,
-        code,
+        code=None,
         inputs=None,
         outputs=None,
         arguments=None,
-        job_name=None,
-        kms_key=None,
     ):
-        # TODO: description
-        normalized_inputs, normalized_outputs = self._normalize_args(
-            job_name=job_name,
-            arguments=arguments,
-            inputs=inputs,
-            outputs=outputs,
-            code=code,
-            kms_key=kms_key,
-        )
+        """Returns a RunArgs object. For processors (:class:`~sagemaker.spark.processing.PySparkProcessor`,
+        :class:`~sagemaker.spark.processing.SparkJarProcessor`) that have special
+        run() arguments, this object contains the normalized arguments for passing to
+        :class:`~sagemaker.workflow.steps.ProcessingStep`.

-        return RunArgs(inputs=normalized_inputs, outputs=normalized_outputs, code=code)
+        Args:
+            code (str): This can be an S3 URI or a local path to a file with the framework
+                script to run.
+            inputs (list[:class:`~sagemaker.processing.ProcessingInput`]): Input files for
+                the processing job. These must be provided as
+                :class:`~sagemaker.processing.ProcessingInput` objects (default: None).
+            outputs (list[:class:`~sagemaker.processing.ProcessingOutput`]): Outputs for
+                the processing job. These can be specified as either path strings or
+                :class:`~sagemaker.processing.ProcessingOutput` objects (default: None).
+            arguments (list[str]): A list of string arguments to be passed to a
+                processing job (default: None).
+        """
+        return RunArgs(code=code, inputs=inputs, outputs=outputs, arguments=arguments)

     def run(
         self,
@@ -1195,6 +1200,7 @@ def __init__(
         inputs=None,
         outputs=None,
         code=None,
+        arguments=None,
     ):
         """Initializes a ``ProcessingOutput`` instance.

@@ -1216,6 +1222,7 @@ def __init__(
         self.inputs = inputs
         self.outputs = outputs
         self.code = code
+        self.arguments = arguments


 class FeatureStoreOutput(ApiObject):
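
The net effect in processing.py is that ``get_run_args`` no longer normalizes anything itself; it just bundles ``code``, ``inputs``, ``outputs``, and ``arguments`` into a ``RunArgs`` object, deferring normalization to the pipeline step. A minimal usage sketch under that assumption; the image URI, role, bucket, and script names below are placeholders, not part of this commit:

# Hedged sketch only: image_uri, role, bucket, and script names are placeholders.
from sagemaker.processing import ProcessingInput, ScriptProcessor
from sagemaker.workflow.steps import ProcessingStep

processor = ScriptProcessor(
    image_uri="<image-uri>",        # placeholder
    command=["python3"],
    role="<execution-role-arn>",    # placeholder
    instance_count=1,
    instance_type="ml.m5.xlarge",
)

# After this commit, get_run_args() just packages the raw arguments;
# input/output normalization happens later, inside ProcessingStep.arguments.
run_args = processor.get_run_args(
    code="preprocess.py",
    inputs=[
        ProcessingInput(
            source="s3://<bucket>/raw",             # placeholder
            destination="/opt/ml/processing/input",
        )
    ],
    arguments=["--split-ratio", "0.2"],
)

step = ProcessingStep(
    name="Preprocess",
    processor=processor,
    inputs=run_args.inputs,
    outputs=run_args.outputs,
    job_arguments=run_args.arguments,
    code=run_args.code,
)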

src/sagemaker/spark/processing.py

Lines changed: 78 additions & 17 deletions
@@ -173,21 +173,33 @@ def __init__(

     def get_run_args(
         self,
-        submit_app,
+        code=None,
         inputs=None,
         outputs=None,
         arguments=None,
-        job_name=None,
-        kms_key=None,
     ):
-        # TODO: description
+        """Returns a RunArgs object. For processors (:class:`~sagemaker.spark.processing.PySparkProcessor`,
+        :class:`~sagemaker.spark.processing.SparkJarProcessor`) that have special
+        run() arguments, this object contains the normalized arguments for passing to
+        :class:`~sagemaker.workflow.steps.ProcessingStep`.
+
+        Args:
+            code (str): This can be an S3 URI or a local path to a file with the framework
+                script to run.
+            inputs (list[:class:`~sagemaker.processing.ProcessingInput`]): Input files for
+                the processing job. These must be provided as
+                :class:`~sagemaker.processing.ProcessingInput` objects (default: None).
+            outputs (list[:class:`~sagemaker.processing.ProcessingOutput`]): Outputs for
+                the processing job. These can be specified as either path strings or
+                :class:`~sagemaker.processing.ProcessingOutput` objects (default: None).
+            arguments (list[str]): A list of string arguments to be passed to a
+                processing job (default: None).
+        """
         return super().get_run_args(
-            code=submit_app,
+            code=code,
             inputs=inputs,
             outputs=outputs,
             arguments=arguments,
-            job_name=job_name,
-            kms_key=kms_key,
         )

     def run(
@@ -716,8 +728,35 @@ def get_run_args(
         job_name=None,
         configuration=None,
         spark_event_logs_s3_uri=None,
-        kms_key=None,
     ):
+        """Returns a RunArgs object. This object contains the normalized inputs, outputs,
+        and arguments needed when using a ``PySparkProcessor`` in a :class:`~sagemaker.workflow.steps.ProcessingStep`.
+
+        Args:
+            submit_app (str): Path (local or S3) to Python file to submit to Spark
+                as the primary application.
+            submit_py_files (list[str]): List of paths (local or S3) to provide for the
+                `spark-submit --py-files` option.
+            submit_jars (list[str]): List of paths (local or S3) to provide for the
+                `spark-submit --jars` option.
+            submit_files (list[str]): List of paths (local or S3) to provide for the
+                `spark-submit --files` option.
+            inputs (list[:class:`~sagemaker.processing.ProcessingInput`]): Input files for
+                the processing job. These must be provided as
+                :class:`~sagemaker.processing.ProcessingInput` objects (default: None).
+            outputs (list[:class:`~sagemaker.processing.ProcessingOutput`]): Outputs for
+                the processing job. These can be specified as either path strings or
+                :class:`~sagemaker.processing.ProcessingOutput` objects (default: None).
+            arguments (list[str]): A list of string arguments to be passed to a
+                processing job (default: None).
+            job_name (str): Processing job name. If not specified, the processor generates
+                a default job name, based on the base job name and current timestamp.
+            configuration (list[dict] or dict): Configuration for Hadoop, Spark, or Hive.
+                List or dictionary of EMR-style classifications.
+                https://docs.aws.amazon.com/emr/latest/ReleaseGuide/emr-configure-apps.html
+            spark_event_logs_s3_uri (str): S3 path where Spark application events will
+                be published to.
+        """
         self._current_job_name = self._generate_current_job_name(job_name=job_name)
         self.command = [_SparkProcessorBase._default_command]

@@ -734,14 +773,11 @@ def get_run_args(
             spark_event_logs_s3_uri=spark_event_logs_s3_uri,
         )

-        # TODO: description
         return super().get_run_args(
-            submit_app=submit_app,
+            code=submit_app,
             inputs=extended_inputs,
             outputs=extended_outputs,
             arguments=arguments,
-            job_name=self._current_job_name,
-            kms_key=kms_key,
         )

     def run(
@@ -821,6 +857,7 @@ def run(
             logs=logs,
             job_name=self._current_job_name,
             experiment_config=experiment_config,
+            kms_key=kms_key,
         )

     def _extend_processing_args(self, inputs, outputs, **kwargs):
@@ -937,8 +974,35 @@ def get_run_args(
         job_name=None,
         configuration=None,
         spark_event_logs_s3_uri=None,
-        kms_key=None,
     ):
+        """Returns a RunArgs object. This object contains the normalized inputs, outputs,
+        and arguments needed when using a ``SparkJarProcessor`` in a :class:`~sagemaker.workflow.steps.ProcessingStep`.
+
+        Args:
+            submit_app (str): Path (local or S3) to Jar file to submit to Spark
+                as the primary application.
+            submit_class (str): Java class reference to submit to Spark as the primary
+                application.
+            submit_jars (list[str]): List of paths (local or S3) to provide for the
+                `spark-submit --jars` option.
+            submit_files (list[str]): List of paths (local or S3) to provide for the
+                `spark-submit --files` option.
+            inputs (list[:class:`~sagemaker.processing.ProcessingInput`]): Input files for
+                the processing job. These must be provided as
+                :class:`~sagemaker.processing.ProcessingInput` objects (default: None).
+            outputs (list[:class:`~sagemaker.processing.ProcessingOutput`]): Outputs for
+                the processing job. These can be specified as either path strings or
+                :class:`~sagemaker.processing.ProcessingOutput` objects (default: None).
+            arguments (list[str]): A list of string arguments to be passed to a
+                processing job (default: None).
+            job_name (str): Processing job name. If not specified, the processor generates
+                a default job name, based on the base job name and current timestamp.
+            configuration (list[dict] or dict): Configuration for Hadoop, Spark, or Hive.
+                List or dictionary of EMR-style classifications.
+                https://docs.aws.amazon.com/emr/latest/ReleaseGuide/emr-configure-apps.html
+            spark_event_logs_s3_uri (str): S3 path where Spark application events will
+                be published to.
+        """
         self._current_job_name = self._generate_current_job_name(job_name=job_name)
         self.command = [_SparkProcessorBase._default_command]

@@ -955,14 +1019,11 @@ def get_run_args(
             spark_event_logs_s3_uri=spark_event_logs_s3_uri,
        )

-        # TODO: description
         return super().get_run_args(
-            submit_app=submit_app,
+            code=submit_app,
             inputs=extended_inputs,
             outputs=extended_outputs,
             arguments=arguments,
-            job_name=self._current_job_name,
-            kms_key=kms_key,
         )

     def run(
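
For the Spark processors, ``submit_app`` is now mapped onto the base class's generic ``code`` argument, and ``kms_key`` is dropped from ``get_run_args`` in favor of the ``run()`` call and the pipeline step. A hedged sketch of the resulting flow; the role, framework version, and app path are illustrative placeholders, not part of this commit:

# Hedged sketch only: role, framework_version, and the app path are placeholders.
from sagemaker.spark.processing import PySparkProcessor
from sagemaker.workflow.steps import ProcessingStep

spark_processor = PySparkProcessor(
    base_job_name="spark-preprocess",
    framework_version="2.4",        # illustrative version
    role="<execution-role-arn>",    # placeholder
    instance_count=2,
    instance_type="ml.m5.xlarge",
)

# submit_app is translated to the generic `code` argument internally.
run_args = spark_processor.get_run_args(
    submit_app="preprocess.py",     # placeholder script
    arguments=["--input", "s3://<bucket>/raw"],
)

# kms_key is no longer accepted by get_run_args(); per the steps.py diff below,
# it is passed to ProcessingStep instead.
step = ProcessingStep(
    name="SparkPreprocess",
    processor=spark_processor,
    inputs=run_args.inputs,
    outputs=run_args.outputs,
    job_arguments=run_args.arguments,
    code=run_args.code,
)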

src/sagemaker/workflow/steps.py

Lines changed: 6 additions & 0 deletions
@@ -320,6 +320,7 @@ def __init__(
         code: str = None,
         property_files: List[PropertyFile] = None,
         cache_config: CacheConfig = None,
+        kms_key=None,
     ):
         """Construct a ProcessingStep, given a `Processor` instance.

@@ -340,6 +341,8 @@ def __init__(
             property_files (List[PropertyFile]): A list of property files that workflow looks
                 for and resolves from the configured processing output list.
             cache_config (CacheConfig): A `sagemaker.workflow.steps.CacheConfig` instance.
+            kms_key (str): The ARN of the KMS key that is used to encrypt the
+                user code file (default: None)
         """
         super(ProcessingStep, self).__init__(name, StepTypeEnum.PROCESSING)
         self.processor = processor
@@ -348,6 +351,7 @@ def __init__(
         self.job_arguments = job_arguments
         self.code = code
         self.property_files = property_files
+        self.kms_key = kms_key

     # Examine why run method in sagemaker.processing.Processor mutates the processor instance
     # by setting the instance's arguments attribute. Refactor Processor.run, if possible.
@@ -370,7 +374,9 @@ def arguments(self) -> RequestType:
             inputs=self.inputs,
             outputs=self.outputs,
             code=self.code,
+            kms_key=self.kms_key,
         )
+
         process_args = ProcessingJob._get_process_args(
             self.processor, normalized_inputs, normalized_outputs, experiment_config=dict()
         )
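
Taken together, the step now owns the KMS key: ``ProcessingStep`` stores ``kms_key`` and forwards it to the processor's ``_normalize_args`` when building the request arguments, so the uploaded user code file can still be encrypted. A minimal sketch, reusing the ``processor`` object from the first example above; the key ARN is a placeholder:

# Hedged sketch only, reusing `processor` from the first example; the ARN is a
# placeholder. The key now rides on the step and is forwarded to
# processor._normalize_args() when the step's request arguments are built.
step = ProcessingStep(
    name="Preprocess",
    processor=processor,
    code="preprocess.py",
    kms_key="arn:aws:kms:<region>:<account-id>:key/<key-id>",  # placeholder
)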
