feature: Support all processor types in ProcessingStep #2209

Merged 15 commits on Mar 17, 2021
60 changes: 58 additions & 2 deletions src/sagemaker/processing.py
@@ -20,6 +20,7 @@

import os
import pathlib
import attr

from six.moves.urllib.parse import urlparse
from six.moves.urllib.request import url2pathname
@@ -207,13 +208,13 @@ def _normalize_args(
inputs (list[:class:`~sagemaker.processing.ProcessingInput`]): Input files for
the processing job. These must be provided as
:class:`~sagemaker.processing.ProcessingInput` objects (default: None).
outputs (list[:class:`~sagemaker.processing.ProcessingOutput`]): Outputs for
the processing job. These can be specified as either path strings or
:class:`~sagemaker.processing.ProcessingOutput` objects (default: None).
code (str): This can be an S3 URI or a local path to a file with the framework
script to run (default: None). A no-op in the base class.
kms_key (str): The ARN of the KMS key that is used to encrypt the
user code file (default: None).
"""
self._current_job_name = self._generate_current_job_name(job_name=job_name)

@@ -442,6 +443,34 @@ def __init__(
network_config=network_config,
)

def get_run_args(
self,
code,
inputs=None,
outputs=None,
arguments=None,
):
"""Returns a RunArgs object.

For processors (:class:`~sagemaker.spark.processing.PySparkProcessor`,
:class:`~sagemaker.spark.processing.SparkJarProcessor`) that have special
run() arguments, this object contains the normalized arguments for passing to
:class:`~sagemaker.workflow.steps.ProcessingStep`.

Args:
code (str): This can be an S3 URI or a local path to a file with the framework
script to run.
inputs (list[:class:`~sagemaker.processing.ProcessingInput`]): Input files for
the processing job. These must be provided as
:class:`~sagemaker.processing.ProcessingInput` objects (default: None).
outputs (list[:class:`~sagemaker.processing.ProcessingOutput`]): Outputs for
the processing job. These can be specified as either path strings or
:class:`~sagemaker.processing.ProcessingOutput` objects (default: None).
arguments (list[str]): A list of string arguments to be passed to a
processing job (default: None).
"""
return RunArgs(code=code, inputs=inputs, outputs=outputs, arguments=arguments)
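
A minimal usage sketch (illustrative, not part of this diff): a script processor's ``get_run_args()`` normalizes the arguments once, and the resulting ``RunArgs`` fields feed a :class:`~sagemaker.workflow.steps.ProcessingStep`. The role ARN, bucket, and script path below are placeholders.

```python
from sagemaker.processing import ProcessingInput
from sagemaker.sklearn.processing import SKLearnProcessor
from sagemaker.workflow.steps import ProcessingStep

processor = SKLearnProcessor(
    framework_version="0.23-1",
    role="arn:aws:iam::123456789012:role/SageMakerRole",  # placeholder role
    instance_type="ml.m5.xlarge",
    instance_count=1,
)

# Normalize the run() arguments without starting a job.
run_args = processor.get_run_args(
    code="preprocess.py",  # hypothetical local script
    inputs=[ProcessingInput(source="s3://my-bucket/raw", destination="/opt/ml/processing/input")],
    arguments=["--split-ratio", "0.2"],
)

# Hand the normalized fields to the pipeline step.
step = ProcessingStep(
    name="Preprocess",
    processor=processor,
    inputs=run_args.inputs,
    outputs=run_args.outputs,
    job_arguments=run_args.arguments,
    code=run_args.code,
)
```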

def run(
self,
code,
@@ -1144,6 +1173,33 @@ def _to_request_dict(self):
return s3_output_request


@attr.s
class RunArgs(object):
"""Accepts parameters that correspond to ScriptProcessors.

An instance of this class is returned from the ``get_run_args()`` method on processors
and is used to normalize the arguments so that they can be passed to
:class:`~sagemaker.workflow.steps.ProcessingStep`.

Args:
code (str): This can be an S3 URI or a local path to a file with the framework
script to run.
inputs (list[:class:`~sagemaker.processing.ProcessingInput`]): Input files for
the processing job. These must be provided as
:class:`~sagemaker.processing.ProcessingInput` objects (default: None).
outputs (list[:class:`~sagemaker.processing.ProcessingOutput`]): Outputs for
the processing job. These can be specified as either path strings or
:class:`~sagemaker.processing.ProcessingOutput` objects (default: None).
arguments (list[str]): A list of string arguments to be passed to a
processing job (default: None).
"""

code = attr.ib()
inputs = attr.ib(default=None)
outputs = attr.ib(default=None)
arguments = attr.ib(default=None)
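
Because ``RunArgs`` is declared with ``attr.s``, attrs generates the keyword constructor, ``repr``, and equality methods from the ``attr.ib()`` declarations; a quick sketch (the URI is a placeholder):

```python
from sagemaker.processing import RunArgs

args = RunArgs(code="s3://my-bucket/code/preprocess.py")  # placeholder URI
assert args.inputs is None and args.arguments is None     # attr.ib defaults
print(args)  # attrs-generated repr shows all four fields
```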


class FeatureStoreOutput(ApiObject):
"""Configuration for processing job outputs in Amazon SageMaker Feature Store."""

180 changes: 174 additions & 6 deletions src/sagemaker/spark/processing.py
@@ -171,6 +171,39 @@ def __init__(
network_config=network_config,
)

def get_run_args(
self,
code,
inputs=None,
outputs=None,
arguments=None,
):
"""Returns a RunArgs object.

For processors (:class:`~sagemaker.spark.processing.PySparkProcessor`,
:class:`~sagemaker.spark.processing.SparkJarProcessor`) that have special
run() arguments, this object contains the normalized arguments for passing to
:class:`~sagemaker.workflow.steps.ProcessingStep`.

Args:
code (str): This can be an S3 URI or a local path to a file with the framework
script to run.
inputs (list[:class:`~sagemaker.processing.ProcessingInput`]): Input files for
the processing job. These must be provided as
:class:`~sagemaker.processing.ProcessingInput` objects (default: None).
outputs (list[:class:`~sagemaker.processing.ProcessingOutput`]): Outputs for
the processing job. These can be specified as either path strings or
:class:`~sagemaker.processing.ProcessingOutput` objects (default: None).
arguments (list[str]): A list of string arguments to be passed to a
processing job (default: None).
"""
return super().get_run_args(
code=code,
inputs=inputs,
outputs=outputs,
arguments=arguments,
)

def run(
self,
submit_app,
@@ -685,6 +718,73 @@ def __init__(
network_config=network_config,
)

def get_run_args(
self,
submit_app,
submit_py_files=None,
submit_jars=None,
submit_files=None,
inputs=None,
outputs=None,
arguments=None,
job_name=None,
configuration=None,
spark_event_logs_s3_uri=None,
):
"""Returns a RunArgs object.

This object contains the normalized inputs, outputs,
and arguments needed when using a ``PySparkProcessor``
in a :class:`~sagemaker.workflow.steps.ProcessingStep`.

Args:
submit_app (str): Path (local or S3) to Python file to submit to Spark
as the primary application. This is translated to the `code`
property on the returned `RunArgs` object.
submit_py_files (list[str]): List of paths (local or S3) to provide for the
`spark-submit --py-files` option.
submit_jars (list[str]): List of paths (local or S3) to provide for the
`spark-submit --jars` option.
submit_files (list[str]): List of paths (local or S3) to provide for the
`spark-submit --files` option.
inputs (list[:class:`~sagemaker.processing.ProcessingInput`]): Input files for
the processing job. These must be provided as
:class:`~sagemaker.processing.ProcessingInput` objects (default: None).
outputs (list[:class:`~sagemaker.processing.ProcessingOutput`]): Outputs for
the processing job. These can be specified as either path strings or
:class:`~sagemaker.processing.ProcessingOutput` objects (default: None).
arguments (list[str]): A list of string arguments to be passed to a
processing job (default: None).
job_name (str): Processing job name. If not specified, the processor generates
a default job name based on the base job name and the current timestamp.
configuration (list[dict] or dict): Configuration for Hadoop, Spark, or Hive.
List or dictionary of EMR-style classifications.
https://docs.aws.amazon.com/emr/latest/ReleaseGuide/emr-configure-apps.html
spark_event_logs_s3_uri (str): S3 path where Spark application events will
be published.
"""
self._current_job_name = self._generate_current_job_name(job_name=job_name)

if not submit_app:
raise ValueError("submit_app is required")

extended_inputs, extended_outputs = self._extend_processing_args(
inputs=inputs,
outputs=outputs,
submit_py_files=submit_py_files,
submit_jars=submit_jars,
submit_files=submit_files,
configuration=configuration,
spark_event_logs_s3_uri=spark_event_logs_s3_uri,
)

return super().get_run_args(
code=submit_app,
inputs=extended_inputs,
outputs=extended_outputs,
arguments=arguments,
)
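
A hedged usage sketch (bucket, role, and script names are placeholders): ``get_run_args()`` folds the ``submit_*`` options into the normalized inputs and command line, so the resulting ``RunArgs`` drives a ``ProcessingStep`` the same way as for plain script processors.

```python
from sagemaker.spark.processing import PySparkProcessor
from sagemaker.workflow.steps import ProcessingStep

spark_processor = PySparkProcessor(
    base_job_name="sm-spark",
    framework_version="2.4",
    role="arn:aws:iam::123456789012:role/SageMakerRole",  # placeholder
    instance_type="ml.m5.xlarge",
    instance_count=2,
)

run_args = spark_processor.get_run_args(
    submit_app="preprocess.py",                   # hypothetical PySpark script
    submit_py_files=["helpers.py"],               # shipped via --py-files
    arguments=["--input", "s3://my-bucket/raw"],  # placeholder S3 path
)

step = ProcessingStep(
    name="SparkPreprocess",
    processor=spark_processor,
    inputs=run_args.inputs,    # includes inputs generated for --py-files etc.
    outputs=run_args.outputs,
    job_arguments=run_args.arguments,
    code=run_args.code,        # submit_app, surfaced as `code`
)
```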

def run(
self,
submit_app,
@@ -738,14 +838,13 @@ def run(
user code file (default: None).
"""
self._current_job_name = self._generate_current_job_name(job_name=job_name)

if not submit_app:
raise ValueError("submit_app is required")

extended_inputs, extended_outputs = self._extend_processing_args(
inputs=inputs,
outputs=outputs,
submit_py_files=submit_py_files,
submit_jars=submit_jars,
submit_files=submit_files,
@@ -762,6 +861,7 @@ def _extend_processing_args(self, inputs, outputs, **kwargs):
logs=logs,
job_name=self._current_job_name,
experiment_config=experiment_config,
kms_key=kms_key,
)

def _extend_processing_args(self, inputs, outputs, **kwargs):
@@ -772,6 +872,7 @@ def _extend_processing_args(self, inputs, outputs, **kwargs):
outputs: Processing outputs.
kwargs: Additional keyword arguments passed to `super()`.
"""
self.command = [_SparkProcessorBase._default_command]
extended_inputs = self._handle_script_dependencies(
inputs, kwargs.get("submit_py_files"), FileType.PYTHON
)
@@ -866,6 +967,73 @@ def __init__(
network_config=network_config,
)

def get_run_args(
self,
submit_app,
submit_class=None,
submit_jars=None,
submit_files=None,
inputs=None,
outputs=None,
arguments=None,
job_name=None,
configuration=None,
spark_event_logs_s3_uri=None,
):
"""Returns a RunArgs object.

This object contains the normalized inputs, outputs,
and arguments needed when using a ``SparkJarProcessor``
in a :class:`~sagemaker.workflow.steps.ProcessingStep`.

Args:
submit_app (str): Path (local or S3) to Jar file to submit to Spark
as the primary application. This is translated to the `code`
property on the returned `RunArgs` object.
submit_class (str): Java class reference to submit to Spark as the primary
application.
submit_jars (list[str]): List of paths (local or S3) to provide for the
`spark-submit --jars` option.
submit_files (list[str]): List of paths (local or S3) to provide for the
`spark-submit --files` option.
inputs (list[:class:`~sagemaker.processing.ProcessingInput`]): Input files for
the processing job. These must be provided as
:class:`~sagemaker.processing.ProcessingInput` objects (default: None).
outputs (list[:class:`~sagemaker.processing.ProcessingOutput`]): Outputs for
the processing job. These can be specified as either path strings or
:class:`~sagemaker.processing.ProcessingOutput` objects (default: None).
arguments (list[str]): A list of string arguments to be passed to a
processing job (default: None).
job_name (str): Processing job name. If not specified, the processor generates
a default job name based on the base job name and the current timestamp.
configuration (list[dict] or dict): Configuration for Hadoop, Spark, or Hive.
List or dictionary of EMR-style classifications.
https://docs.aws.amazon.com/emr/latest/ReleaseGuide/emr-configure-apps.html
spark_event_logs_s3_uri (str): S3 path where Spark application events will
be published.
"""
self._current_job_name = self._generate_current_job_name(job_name=job_name)

if not submit_app:
raise ValueError("submit_app is required")

extended_inputs, extended_outputs = self._extend_processing_args(
inputs=inputs,
outputs=outputs,
submit_class=submit_class,
submit_jars=submit_jars,
submit_files=submit_files,
configuration=configuration,
spark_event_logs_s3_uri=spark_event_logs_s3_uri,
)

return super().get_run_args(
code=submit_app,
inputs=extended_inputs,
outputs=extended_outputs,
arguments=arguments,
)
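
The jar variant mirrors the PySpark sketch; the jar path and main class below are hypothetical:

```python
from sagemaker.spark.processing import SparkJarProcessor

jar_processor = SparkJarProcessor(
    base_job_name="sm-spark-java",
    framework_version="2.4",
    role="arn:aws:iam::123456789012:role/SageMakerRole",  # placeholder
    instance_type="ml.m5.xlarge",
    instance_count=2,
)

# submit_class is forwarded to `spark-submit --class` by _extend_processing_args().
run_args = jar_processor.get_run_args(
    submit_app="s3://my-bucket/jars/preprocess.jar",  # placeholder jar
    submit_class="com.example.Preprocess",            # hypothetical main class
)
```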

def run(
self,
submit_app,
@@ -919,14 +1087,13 @@ def run(
user code file (default: None).
"""
self._current_job_name = self._generate_current_job_name(job_name=job_name)

if not submit_app:
raise ValueError("submit_app is required")

extended_inputs, extended_outputs = self._extend_processing_args(
inputs=inputs,
outputs=outputs,
submit_class=submit_class,
submit_jars=submit_jars,
submit_files=submit_files,
Expand All @@ -947,6 +1114,7 @@ def run(
)

def _extend_processing_args(self, inputs, outputs, **kwargs):
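# Initialize the command here so that both run() and get_run_args() set it
# before --class is appended below.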
self.command = [_SparkProcessorBase._default_command]
if kwargs.get("submit_class"):
self.command.extend(["--class", kwargs.get("submit_class")])
else:
1 change: 1 addition & 0 deletions src/sagemaker/workflow/steps.py
@@ -371,6 +371,7 @@ def arguments(self) -> RequestType:
outputs=self.outputs,
code=self.code,
)

process_args = ProcessingJob._get_process_args(
self.processor, normalized_inputs, normalized_outputs, experiment_config=dict()
)