Commit 229ead5
Author: Payton Staub

Interim commit: support all processor types in ProcessingStep

1 parent eabf1bd · commit 229ead5

File tree

3 files changed: +420 -1 lines changed

src/sagemaker/processing.py

Lines changed: 74 additions & 0 deletions
@@ -123,6 +123,25 @@ def __init__(

         self.sagemaker_session = sagemaker_session or Session()

+    def get_run_args(
+        self,
+        inputs=None,
+        outputs=None,
+        arguments=None,
+        job_name=None,
+        kms_key=None,
+    ):
+        """Normalizes the arguments of a processing job run without starting the job."""
+        normalized_inputs, normalized_outputs = self._normalize_args(
+            job_name=job_name,
+            arguments=arguments,
+            inputs=inputs,
+            kms_key=kms_key,
+            outputs=outputs,
+        )
+
+        return RunArgs(inputs=normalized_inputs, outputs=normalized_outputs, code=None)
+
     def run(
         self,
         inputs=None,
@@ -442,6 +461,27 @@ def __init__(
             network_config=network_config,
         )

+    def get_run_args(
+        self,
+        code,
+        inputs=None,
+        outputs=None,
+        arguments=None,
+        job_name=None,
+        kms_key=None,
+    ):
+        """Normalizes code and arguments of a processing job run without starting the job."""
+        normalized_inputs, normalized_outputs = self._normalize_args(
+            job_name=job_name,
+            arguments=arguments,
+            inputs=inputs,
+            outputs=outputs,
+            code=code,
+            kms_key=kms_key,
+        )
+
+        return RunArgs(inputs=normalized_inputs, outputs=normalized_outputs, code=code)
+
     def run(
         self,
         code,
@@ -1144,6 +1184,40 @@ def _to_request_dict(self):
         return s3_output_request


+class RunArgs(object):
+    """Holds the normalized arguments of a processing job run.
+
+    An instance of this class is returned by the ``get_run_args()`` method on a
+    processor. It carries the normalized inputs, outputs, and code so that they
+    can be passed to a ``ProcessingStep`` instead of starting the job directly.
+    """
+
+    def __init__(
+        self,
+        inputs=None,
+        outputs=None,
+        code=None,
+    ):
+        """Initializes a ``RunArgs`` instance.
+
+        Args:
+            inputs (list[:class:`~sagemaker.processing.ProcessingInput`]): Input files for
+                the processing job.
+            outputs (list[:class:`~sagemaker.processing.ProcessingOutput`]): Outputs for
+                the processing job.
+            code (str): An S3 URI or a local path to a file with the user script to run.
+        """
+        self.inputs = inputs
+        self.outputs = outputs
+        self.code = code
+
+
 class FeatureStoreOutput(ApiObject):
     """Configuration for processing job outputs in Amazon SageMaker Feature Store."""

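Taken together, these changes let a caller normalize a processing job's arguments without launching the job, which is what a pipeline ProcessingStep needs. A minimal sketch of the intended flow (the processor configuration, S3 paths, and the ProcessingStep keyword names below are illustrative assumptions, not part of this diff):

# Sketch only: the image URI, role ARN, and S3 paths are placeholders.
from sagemaker.processing import ProcessingInput, ProcessingOutput, ScriptProcessor
from sagemaker.workflow.steps import ProcessingStep

processor = ScriptProcessor(
    image_uri="<image-uri>",
    command=["python3"],
    role="<role-arn>",
    instance_count=1,
    instance_type="ml.m5.xlarge",
)

# Normalize inputs/outputs/code without starting a processing job.
run_args = processor.get_run_args(
    code="preprocess.py",
    inputs=[
        ProcessingInput(
            source="s3://<bucket>/raw",
            destination="/opt/ml/processing/input",
        )
    ],
    outputs=[ProcessingOutput(source="/opt/ml/processing/output")],
)

# Hand the normalized arguments to a pipeline step instead of calling run().
step = ProcessingStep(
    name="Preprocess",
    processor=processor,
    inputs=run_args.inputs,
    outputs=run_args.outputs,
    code=run_args.code,
)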
src/sagemaker/spark/processing.py

Lines changed: 100 additions & 1 deletion
@@ -33,7 +33,7 @@

 from sagemaker import image_uris
 from sagemaker.local.image import _ecr_login_if_needed, _pull_image
-from sagemaker.processing import ProcessingInput, ProcessingOutput, ScriptProcessor
+from sagemaker.processing import ProcessingInput, ProcessingOutput, ScriptProcessor, RunArgs
 from sagemaker.s3 import S3Uploader
 from sagemaker.session import Session
 from sagemaker.spark import defaults
@@ -171,6 +171,25 @@ def __init__(
             network_config=network_config,
         )

+    def get_run_args(
+        self,
+        submit_app,
+        inputs=None,
+        outputs=None,
+        arguments=None,
+        job_name=None,
+        kms_key=None,
+    ):
+        """Normalizes Spark job run arguments, forwarding submit_app as code."""
+        return super().get_run_args(
+            code=submit_app,
+            inputs=inputs,
+            outputs=outputs,
+            arguments=arguments,
+            job_name=job_name,
+            kms_key=kms_key,
+        )
+
     def run(
         self,
         submit_app,
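The base-class override above is a thin adapter: the Spark processors expose the user application as submit_app, so it is simply forwarded as the code argument that ScriptProcessor.get_run_args expects.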
@@ -685,6 +704,46 @@ def __init__(
             network_config=network_config,
         )

+    def get_run_args(
+        self,
+        submit_app,
+        submit_py_files=None,
+        submit_jars=None,
+        submit_files=None,
+        inputs=None,
+        outputs=None,
+        arguments=None,
+        job_name=None,
+        configuration=None,
+        spark_event_logs_s3_uri=None,
+        kms_key=None,
+    ):
+        """Normalizes the Spark submit and job run arguments without starting the job."""
+        self._current_job_name = self._generate_current_job_name(job_name=job_name)
+        self.command = [_SparkProcessorBase._default_command]
+
+        if not submit_app:
+            raise ValueError("submit_app is required")
+
+        extended_inputs, extended_outputs = self._extend_processing_args(
+            inputs,
+            outputs,
+            submit_py_files=submit_py_files,
+            submit_jars=submit_jars,
+            submit_files=submit_files,
+            configuration=configuration,
+            spark_event_logs_s3_uri=spark_event_logs_s3_uri,
+        )
+
+        return super().get_run_args(
+            submit_app=submit_app,
+            inputs=extended_inputs,
+            outputs=extended_outputs,
+            arguments=arguments,
+            job_name=self._current_job_name,
+            kms_key=kms_key,
+        )
+
     def run(
         self,
         submit_app,
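For the Spark processors, get_run_args additionally folds the submit_* arguments into the processing inputs before normalizing. A sketch under assumed placeholder values (role ARN, bucket, and framework version are not from this diff):

# Sketch only: role ARN, bucket, and framework version are placeholders.
from sagemaker.spark.processing import PySparkProcessor

spark_processor = PySparkProcessor(
    base_job_name="spark-preprocess",
    framework_version="2.4",
    role="<role-arn>",
    instance_count=2,
    instance_type="ml.m5.xlarge",
)

run_args = spark_processor.get_run_args(
    submit_app="preprocess.py",
    submit_py_files=["s3://<bucket>/helpers.py"],
    arguments=["--input", "s3://<bucket>/raw"],
    spark_event_logs_s3_uri="s3://<bucket>/spark-events",
)

# run_args.inputs now carries the uploaded submit_app and submit_py_files
# alongside any user-supplied inputs, ready for a ProcessingStep.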
@@ -866,6 +925,46 @@ def __init__(
             network_config=network_config,
         )

+    def get_run_args(
+        self,
+        submit_app,
+        submit_class=None,
+        submit_jars=None,
+        submit_files=None,
+        inputs=None,
+        outputs=None,
+        arguments=None,
+        job_name=None,
+        configuration=None,
+        spark_event_logs_s3_uri=None,
+        kms_key=None,
+    ):
+        """Normalizes the Spark submit and job run arguments without starting the job."""
+        self._current_job_name = self._generate_current_job_name(job_name=job_name)
+        self.command = [_SparkProcessorBase._default_command]
+
+        if not submit_app:
+            raise ValueError("submit_app is required")
+
+        extended_inputs, extended_outputs = self._extend_processing_args(
+            inputs,
+            outputs,
+            submit_class=submit_class,
+            submit_jars=submit_jars,
+            submit_files=submit_files,
+            configuration=configuration,
+            spark_event_logs_s3_uri=spark_event_logs_s3_uri,
+        )
+
+        return super().get_run_args(
+            submit_app=submit_app,
+            inputs=extended_inputs,
+            outputs=extended_outputs,
+            arguments=arguments,
+            job_name=self._current_job_name,
+            kms_key=kms_key,
+        )
+
     def run(
         self,
         submit_app,

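The JAR variant is identical apart from submit_class. A sketch with hypothetical placeholder values (the JAR URI and main class are not from this diff):

# Sketch only: the JAR URI and main class are hypothetical placeholders.
from sagemaker.spark.processing import SparkJarProcessor

jar_processor = SparkJarProcessor(
    base_job_name="spark-jar-preprocess",
    framework_version="2.4",
    role="<role-arn>",
    instance_count=2,
    instance_type="ml.m5.xlarge",
)

run_args = jar_processor.get_run_args(
    submit_app="s3://<bucket>/jars/preprocess.jar",
    submit_class="com.example.Preprocess",
    arguments=["--input", "s3://<bucket>/raw"],
)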