
Commit a8ab823

Authored by Payton Staub (staubhp), with co-authors Dan and Ahsan Khan (ahsan-z-khan)
feature: Support all processor types in ProcessingStep (#2209)
Co-authored-by: Payton Staub <[email protected]>
Co-authored-by: Dan <[email protected]>
Co-authored-by: Ahsan Khan <[email protected]>
1 parent fc19b6e commit a8ab823

File tree: 7 files changed (+1012, -10 lines)

src/sagemaker/processing.py

Lines changed: 58 additions & 2 deletions
@@ -20,6 +20,7 @@
 
 import os
 import pathlib
+import attr
 
 from six.moves.urllib.parse import urlparse
 from six.moves.urllib.request import url2pathname
@@ -207,13 +208,13 @@ def _normalize_args(
             inputs (list[:class:`~sagemaker.processing.ProcessingInput`]): Input files for
                 the processing job. These must be provided as
                 :class:`~sagemaker.processing.ProcessingInput` objects (default: None).
-            kms_key (str): The ARN of the KMS key that is used to encrypt the
-                user code file (default: None).
             outputs (list[:class:`~sagemaker.processing.ProcessingOutput`]): Outputs for
                 the processing job. These can be specified as either path strings or
                 :class:`~sagemaker.processing.ProcessingOutput` objects (default: None).
             code (str): This can be an S3 URI or a local path to a file with the framework
                 script to run (default: None). A no op in the base class.
+            kms_key (str): The ARN of the KMS key that is used to encrypt the
+                user code file (default: None).
         """
         self._current_job_name = self._generate_current_job_name(job_name=job_name)
 
@@ -442,6 +443,34 @@ def __init__(
             network_config=network_config,
         )
 
+    def get_run_args(
+        self,
+        code,
+        inputs=None,
+        outputs=None,
+        arguments=None,
+    ):
+        """Returns a RunArgs object.
+
+        For processors (:class:`~sagemaker.spark.processing.PySparkProcessor`,
+        :class:`~sagemaker.spark.processing.SparkJar`) that have special
+        run() arguments, this object contains the normalized arguments for passing to
+        :class:`~sagemaker.workflow.steps.ProcessingStep`.
+
+        Args:
+            code (str): This can be an S3 URI or a local path to a file with the framework
+                script to run.
+            inputs (list[:class:`~sagemaker.processing.ProcessingInput`]): Input files for
+                the processing job. These must be provided as
+                :class:`~sagemaker.processing.ProcessingInput` objects (default: None).
+            outputs (list[:class:`~sagemaker.processing.ProcessingOutput`]): Outputs for
+                the processing job. These can be specified as either path strings or
+                :class:`~sagemaker.processing.ProcessingOutput` objects (default: None).
+            arguments (list[str]): A list of string arguments to be passed to a
+                processing job (default: None).
+        """
+        return RunArgs(code=code, inputs=inputs, outputs=outputs, arguments=arguments)
+
     def run(
         self,
         code,
@@ -1144,6 +1173,33 @@ def _to_request_dict(self):
         return s3_output_request
 
 
+@attr.s
+class RunArgs(object):
+    """Accepts parameters that correspond to ScriptProcessors.
+
+    An instance of this class is returned from the ``get_run_args()`` method on processors,
+    and is used for normalizing the arguments so that they can be passed to
+    :class:`~sagemaker.workflow.steps.ProcessingStep`
+
+    Args:
+        code (str): This can be an S3 URI or a local path to a file with the framework
+            script to run.
+        inputs (list[:class:`~sagemaker.processing.ProcessingInput`]): Input files for
+            the processing job. These must be provided as
+            :class:`~sagemaker.processing.ProcessingInput` objects (default: None).
+        outputs (list[:class:`~sagemaker.processing.ProcessingOutput`]): Outputs for
+            the processing job. These can be specified as either path strings or
+            :class:`~sagemaker.processing.ProcessingOutput` objects (default: None).
+        arguments (list[str]): A list of string arguments to be passed to a
+            processing job (default: None).
+    """
+
+    code = attr.ib()
+    inputs = attr.ib(default=None)
+    outputs = attr.ib(default=None)
+    arguments = attr.ib(default=None)
+
+
 class FeatureStoreOutput(ApiObject):
     """Configuration for processing job outputs in Amazon SageMaker Feature Store."""
 
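Together, ``get_run_args()`` and ``RunArgs`` give processors a uniform way to hand normalized ``code``, ``inputs``, ``outputs``, and ``arguments`` to a pipeline step instead of calling ``run()`` directly. A minimal sketch of the intended wiring, assuming an existing execution role and input data in S3 (the processor choice, script name, S3 URIs, and step name are illustrative assumptions, not part of this commit):

from sagemaker.processing import ProcessingInput, ProcessingOutput
from sagemaker.sklearn.processing import SKLearnProcessor
from sagemaker.workflow.steps import ProcessingStep

# Illustrative processor; any processor exposing get_run_args() should follow the same pattern.
sklearn_processor = SKLearnProcessor(
    framework_version="0.23-1",
    role=role,  # assumed: an existing SageMaker execution role ARN
    instance_type="ml.m5.xlarge",
    instance_count=1,
)

# Normalize the run arguments instead of starting a job with run().
run_args = sklearn_processor.get_run_args(
    code="preprocess.py",  # placeholder script
    inputs=[ProcessingInput(source=input_s3_uri, destination="/opt/ml/processing/input")],
    outputs=[ProcessingOutput(source="/opt/ml/processing/output", output_name="train")],
    arguments=["--split-ratio", "0.8"],
)

# Feed the normalized fields into a ProcessingStep for a SageMaker Pipeline.
step_process = ProcessingStep(
    name="Preprocess",
    processor=sklearn_processor,
    inputs=run_args.inputs,
    outputs=run_args.outputs,
    job_arguments=run_args.arguments,
    code=run_args.code,
)

For the base ``ScriptProcessor`` this is a thin wrapper around ``RunArgs``; the value of the indirection shows up with the Spark processors below, whose ``get_run_args()`` also extends the inputs and outputs before returning them.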
src/sagemaker/spark/processing.py

Lines changed: 174 additions & 6 deletions
@@ -171,6 +171,39 @@ def __init__(
             network_config=network_config,
         )
 
+    def get_run_args(
+        self,
+        code,
+        inputs=None,
+        outputs=None,
+        arguments=None,
+    ):
+        """Returns a RunArgs object.
+
+        For processors (:class:`~sagemaker.spark.processing.PySparkProcessor`,
+        :class:`~sagemaker.spark.processing.SparkJar`) that have special
+        run() arguments, this object contains the normalized arguments for passing to
+        :class:`~sagemaker.workflow.steps.ProcessingStep`.
+
+        Args:
+            code (str): This can be an S3 URI or a local path to a file with the framework
+                script to run.
+            inputs (list[:class:`~sagemaker.processing.ProcessingInput`]): Input files for
+                the processing job. These must be provided as
+                :class:`~sagemaker.processing.ProcessingInput` objects (default: None).
+            outputs (list[:class:`~sagemaker.processing.ProcessingOutput`]): Outputs for
+                the processing job. These can be specified as either path strings or
+                :class:`~sagemaker.processing.ProcessingOutput` objects (default: None).
+            arguments (list[str]): A list of string arguments to be passed to a
+                processing job (default: None).
+        """
+        return super().get_run_args(
+            code=code,
+            inputs=inputs,
+            outputs=outputs,
+            arguments=arguments,
+        )
+
     def run(
         self,
         submit_app,
@@ -685,6 +718,73 @@ def __init__(
             network_config=network_config,
         )
 
+    def get_run_args(
+        self,
+        submit_app,
+        submit_py_files=None,
+        submit_jars=None,
+        submit_files=None,
+        inputs=None,
+        outputs=None,
+        arguments=None,
+        job_name=None,
+        configuration=None,
+        spark_event_logs_s3_uri=None,
+    ):
+        """Returns a RunArgs object.
+
+        This object contains the normalized inputs, outputs
+        and arguments needed when using a ``PySparkProcessor``
+        in a :class:`~sagemaker.workflow.steps.ProcessingStep`.
+
+        Args:
+            submit_app (str): Path (local or S3) to Python file to submit to Spark
+                as the primary application. This is translated to the `code`
+                property on the returned `RunArgs` object.
+            submit_py_files (list[str]): List of paths (local or S3) to provide for
+                `spark-submit --py-files` option
+            submit_jars (list[str]): List of paths (local or S3) to provide for
+                `spark-submit --jars` option
+            submit_files (list[str]): List of paths (local or S3) to provide for
+                `spark-submit --files` option
+            inputs (list[:class:`~sagemaker.processing.ProcessingInput`]): Input files for
+                the processing job. These must be provided as
+                :class:`~sagemaker.processing.ProcessingInput` objects (default: None).
+            outputs (list[:class:`~sagemaker.processing.ProcessingOutput`]): Outputs for
+                the processing job. These can be specified as either path strings or
+                :class:`~sagemaker.processing.ProcessingOutput` objects (default: None).
+            arguments (list[str]): A list of string arguments to be passed to a
+                processing job (default: None).
+            job_name (str): Processing job name. If not specified, the processor generates
+                a default job name, based on the base job name and current timestamp.
+            configuration (list[dict] or dict): Configuration for Hadoop, Spark, or Hive.
+                List or dictionary of EMR-style classifications.
+                https://docs.aws.amazon.com/emr/latest/ReleaseGuide/emr-configure-apps.html
+            spark_event_logs_s3_uri (str): S3 path where spark application events will
+                be published to.
+        """
+        self._current_job_name = self._generate_current_job_name(job_name=job_name)
+
+        if not submit_app:
+            raise ValueError("submit_app is required")
+
+        extended_inputs, extended_outputs = self._extend_processing_args(
+            inputs=inputs,
+            outputs=outputs,
+            submit_py_files=submit_py_files,
+            submit_jars=submit_jars,
+            submit_files=submit_files,
+            configuration=configuration,
+            spark_event_logs_s3_uri=spark_event_logs_s3_uri,
+        )
+
+        return super().get_run_args(
+            code=submit_app,
+            inputs=extended_inputs,
+            outputs=extended_outputs,
+            arguments=arguments,
+        )
+
     def run(
         self,
         submit_app,
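Because ``PySparkProcessor.get_run_args()`` performs the same ``_extend_processing_args()`` extension as ``run()`` but stops short of submitting a job, a PySpark script can now drive a ``ProcessingStep``. A hedged sketch of the intended usage (the role, S3 URIs, script names, and Spark configuration values are illustrative assumptions):

from sagemaker.spark.processing import PySparkProcessor
from sagemaker.workflow.steps import ProcessingStep

spark_processor = PySparkProcessor(
    base_job_name="sm-spark",
    framework_version="2.4",   # assumed available Spark container version
    role=role,                 # assumed: existing execution role ARN
    instance_count=2,
    instance_type="ml.m5.xlarge",
)

run_args = spark_processor.get_run_args(
    submit_app="preprocess.py",            # becomes run_args.code
    submit_py_files=["helpers.py"],        # folded into run_args.inputs
    arguments=["--input", input_s3_uri],   # placeholder script arguments
    configuration={
        "Classification": "spark-defaults",  # EMR-style classification (assumed values)
        "Properties": {"spark.executor.memory": "2g"},
    },
)

step = ProcessingStep(
    name="SparkPreprocess",
    processor=spark_processor,
    inputs=run_args.inputs,    # includes the extended Spark dependency inputs
    outputs=run_args.outputs,
    job_arguments=run_args.arguments,
    code=run_args.code,
)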
@@ -738,14 +838,13 @@ def run(
                 user code file (default: None).
         """
         self._current_job_name = self._generate_current_job_name(job_name=job_name)
-        self.command = [_SparkProcessorBase._default_command]
 
         if not submit_app:
             raise ValueError("submit_app is required")
 
         extended_inputs, extended_outputs = self._extend_processing_args(
-            inputs,
-            outputs,
+            inputs=inputs,
+            outputs=outputs,
             submit_py_files=submit_py_files,
             submit_jars=submit_jars,
             submit_files=submit_files,
@@ -762,6 +861,7 @@ def run(
             logs=logs,
             job_name=self._current_job_name,
             experiment_config=experiment_config,
+            kms_key=kms_key,
         )
 
     def _extend_processing_args(self, inputs, outputs, **kwargs):
@@ -772,6 +872,7 @@ def _extend_processing_args(self, inputs, outputs, **kwargs):
             outputs: Processing outputs.
             kwargs: Additional keyword arguments passed to `super()`.
         """
+        self.command = [_SparkProcessorBase._default_command]
         extended_inputs = self._handle_script_dependencies(
             inputs, kwargs.get("submit_py_files"), FileType.PYTHON
         )
@@ -866,6 +967,73 @@ def __init__(
             network_config=network_config,
         )
 
+    def get_run_args(
+        self,
+        submit_app,
+        submit_class=None,
+        submit_jars=None,
+        submit_files=None,
+        inputs=None,
+        outputs=None,
+        arguments=None,
+        job_name=None,
+        configuration=None,
+        spark_event_logs_s3_uri=None,
+    ):
+        """Returns a RunArgs object.
+
+        This object contains the normalized inputs, outputs
+        and arguments needed when using a ``SparkJarProcessor``
+        in a :class:`~sagemaker.workflow.steps.ProcessingStep`.
+
+        Args:
+            submit_app (str): Path (local or S3) to Python file to submit to Spark
+                as the primary application. This is translated to the `code`
+                property on the returned `RunArgs` object
+            submit_class (str): Java class reference to submit to Spark as the primary
+                application
+            submit_jars (list[str]): List of paths (local or S3) to provide for
+                `spark-submit --jars` option
+            submit_files (list[str]): List of paths (local or S3) to provide for
+                `spark-submit --files` option
+            inputs (list[:class:`~sagemaker.processing.ProcessingInput`]): Input files for
+                the processing job. These must be provided as
+                :class:`~sagemaker.processing.ProcessingInput` objects (default: None).
+            outputs (list[:class:`~sagemaker.processing.ProcessingOutput`]): Outputs for
+                the processing job. These can be specified as either path strings or
+                :class:`~sagemaker.processing.ProcessingOutput` objects (default: None).
+            arguments (list[str]): A list of string arguments to be passed to a
+                processing job (default: None).
+            job_name (str): Processing job name. If not specified, the processor generates
+                a default job name, based on the base job name and current timestamp.
+            configuration (list[dict] or dict): Configuration for Hadoop, Spark, or Hive.
+                List or dictionary of EMR-style classifications.
+                https://docs.aws.amazon.com/emr/latest/ReleaseGuide/emr-configure-apps.html
+            spark_event_logs_s3_uri (str): S3 path where spark application events will
+                be published to.
+        """
+        self._current_job_name = self._generate_current_job_name(job_name=job_name)
+
+        if not submit_app:
+            raise ValueError("submit_app is required")
+
+        extended_inputs, extended_outputs = self._extend_processing_args(
+            inputs=inputs,
+            outputs=outputs,
+            submit_class=submit_class,
+            submit_jars=submit_jars,
+            submit_files=submit_files,
+            configuration=configuration,
+            spark_event_logs_s3_uri=spark_event_logs_s3_uri,
+        )
+
+        return super().get_run_args(
+            code=submit_app,
+            inputs=extended_inputs,
+            outputs=extended_outputs,
+            arguments=arguments,
+        )
+
     def run(
         self,
         submit_app,
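``SparkJarProcessor.get_run_args()`` follows the same shape, with ``submit_class`` naming the main class to run from the jar. A brief hedged sketch (the jar location and class name are placeholders, not from this commit):

from sagemaker.spark.processing import SparkJarProcessor

jar_processor = SparkJarProcessor(
    base_job_name="sm-spark-java",
    framework_version="2.4",
    role=role,                 # assumed: existing execution role ARN
    instance_count=2,
    instance_type="ml.m5.xlarge",
)

run_args = jar_processor.get_run_args(
    submit_app="s3://my-bucket/jars/preprocessing.jar",  # placeholder jar location
    submit_class="com.example.ProcessingApp",            # placeholder main class
    arguments=["--mode", "batch"],
)
# run_args.code / inputs / outputs / arguments are then wired into a ProcessingStep
# exactly as in the PySparkProcessor sketch above.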
@@ -919,14 +1087,13 @@ def run(
                 user code file (default: None).
         """
         self._current_job_name = self._generate_current_job_name(job_name=job_name)
-        self.command = [_SparkProcessorBase._default_command]
 
         if not submit_app:
             raise ValueError("submit_app is required")
 
         extended_inputs, extended_outputs = self._extend_processing_args(
-            inputs,
-            outputs,
+            inputs=inputs,
+            outputs=outputs,
             submit_class=submit_class,
             submit_jars=submit_jars,
             submit_files=submit_files,
@@ -947,6 +1114,7 @@ def run(
         )
 
     def _extend_processing_args(self, inputs, outputs, **kwargs):
+        self.command = [_SparkProcessorBase._default_command]
         if kwargs.get("submit_class"):
             self.command.extend(["--class", kwargs.get("submit_class")])
         else:

src/sagemaker/workflow/steps.py

Lines changed: 1 addition & 0 deletions
@@ -371,6 +371,7 @@ def arguments(self) -> RequestType:
             outputs=self.outputs,
             code=self.code,
         )
+
        process_args = ProcessingJob._get_process_args(
            self.processor, normalized_inputs, normalized_outputs, experiment_config=dict()
        )
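Downstream, the step built from ``run_args`` plugs into a pipeline definition like any other step. A minimal, hedged sketch of that last mile (the pipeline name and session handling are illustrative assumptions):

from sagemaker.workflow.pipeline import Pipeline

pipeline = Pipeline(
    name="spark-processing-pipeline",      # placeholder name
    steps=[step_process],                  # the ProcessingStep built from run_args
    sagemaker_session=sagemaker_session,   # assumed: an existing sagemaker.Session
)
pipeline.upsert(role_arn=role)  # create or update the pipeline definition
execution = pipeline.start()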
