Commit 37a12ee

Author: Payton Staub (committed)
Interim commit #4 - support all processors for ProcessingStep
1 parent 81dfee2 commit 37a12ee

6 files changed (+47 −118 lines)


src/sagemaker/processing.py
Lines changed: 17 additions & 36 deletions

@@ -123,25 +123,6 @@ def __init__(
 
         self.sagemaker_session = sagemaker_session or Session()
 
-    def get_run_args(
-        self,
-        inputs=None,
-        outputs=None,
-        arguments=None,
-        job_name=None,
-        kms_key=None,
-    ):
-        # TODO: description
-        normalized_inputs, normalized_outputs = self._normalize_args(
-            job_name=job_name,
-            arguments=arguments,
-            inputs=inputs,
-            kms_key=kms_key,
-            outputs=outputs,
-        )
-
-        return RunArgs(inputs=normalized_inputs, outputs=normalized_outputs, code=None)
-
     def run(
         self,
         inputs=None,

@@ -1190,34 +1171,34 @@ def _to_request_dict(self):
 
 
 class RunArgs(object):
-    """Accepts parameters that specify an Amazon S3 output for a processing job.
+    """Provides an object containing the standard run arguments needed by
+    :class:`~sagemaker.processing.ScriptProcessor`.
 
-    It also provides a method to turn those parameters into a dictionary.
+    An instance of this class is returned from the ``get_run_args()`` method on processors,
+    and is used for normalizing the arguments so that they can be passed to
+    :class:`~sagemaker.workflow.steps.ProcessingStep`
     """
 
     def __init__(
         self,
+        code=None,
         inputs=None,
         outputs=None,
-        code=None,
         arguments=None,
     ):
-        """Initializes a ``ProcessingOutput`` instance.
-
-        ``ProcessingOutput`` accepts parameters that specify an Amazon S3 output for a
-        processing job and provides a method to turn those parameters into a dictionary.
+        """Initializes a ``RunArgs`` instance.
 
         Args:
-            source (str): The source for the output.
-            destination (str): The destination of the output. If a destination
-                is not provided, one will be generated:
-                "s3://<default-bucket-name>/<job-name>/output/<output-name>".
-            output_name (str): The name of the output. If a name
-                is not provided, one will be generated (eg. "output-1").
-            s3_upload_mode (str): Valid options are "EndOfJob" or "Continuous".
-            app_managed (bool): Whether the input are managed by SageMaker or application
-            feature_store_output (:class:`~sagemaker.processing.FeatureStoreOutput`)
-                Configuration for processing job outputs of FeatureStore.
+            code (str): This can be an S3 URI or a local path to a file with the framework
+                script to run.
+            inputs (list[:class:`~sagemaker.processing.ProcessingInput`]): Input files for
+                the processing job. These must be provided as
+                :class:`~sagemaker.processing.ProcessingInput` objects (default: None).
+            outputs (list[:class:`~sagemaker.processing.ProcessingOutput`]): Outputs for
+                the processing job. These can be specified as either path strings or
+                :class:`~sagemaker.processing.ProcessingOutput` objects (default: None).
+            arguments (list[str]): A list of string arguments to be passed to a
+                processing job (default: None).
         """
         self.inputs = inputs
         self.outputs = outputs
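
The ``get_run_args()``/``RunArgs`` pair exists so that a pipeline author can obtain a processor's normalized inputs, outputs and arguments without starting a job, and then wire those fields into a :class:`~sagemaker.workflow.steps.ProcessingStep`. A minimal sketch of that flow, assuming a ``ScriptProcessor``-style ``get_run_args(code, inputs, outputs, arguments)`` signature (as the unit-test change further below suggests) and the usual ``ProcessingStep`` keyword names; the processor, bucket, script and role values are placeholders, not taken from this diff:

from sagemaker.processing import ProcessingInput, ProcessingOutput
from sagemaker.sklearn.processing import SKLearnProcessor
from sagemaker.workflow.steps import ProcessingStep

# Hypothetical processor configuration, for illustration only.
sklearn_processor = SKLearnProcessor(
    framework_version="0.23-1",
    role="arn:aws:iam::111122223333:role/SageMakerRole",  # placeholder role
    instance_type="ml.m5.xlarge",
    instance_count=1,
)

# get_run_args() normalizes the arguments (the same way run() would) and
# returns them in a RunArgs container instead of launching a processing job.
run_args = sklearn_processor.get_run_args(
    code="preprocess.py",
    inputs=[ProcessingInput(source="s3://my-bucket/raw", destination="/opt/ml/processing/input")],
    outputs=[ProcessingOutput(source="/opt/ml/processing/output")],
    arguments=["--train-test-split-ratio", "0.2"],
)

# The normalized fields are then handed to the pipeline step.
step_process = ProcessingStep(
    name="sklearn-process",
    processor=sklearn_processor,
    code=run_args.code,
    inputs=run_args.inputs,
    outputs=run_args.outputs,
    job_arguments=run_args.arguments,
)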

src/sagemaker/spark/processing.py
Lines changed: 4 additions & 6 deletions

@@ -730,7 +730,7 @@ def get_run_args(
         spark_event_logs_s3_uri=None,
     ):
         """Returns a RunArgs object. This object contains the normalized inputs, outputs
-        and arguments needed when creating using a ``PySparkProcessor`` in a :class:`~sagemaker.workflow.steps.ProcessingStep`.
+        and arguments needed when using a ``PySparkProcessor`` in a :class:`~sagemaker.workflow.steps.ProcessingStep`.
 
         Args:
             submit_app (str): Path (local or S3) to Python file to submit to Spark

@@ -758,7 +758,6 @@ def get_run_args(
                 be published to.
         """
         self._current_job_name = self._generate_current_job_name(job_name=job_name)
-        self.command = [_SparkProcessorBase._default_command]
 
         if not submit_app:
             raise ValueError("submit_app is required")

@@ -833,7 +832,6 @@ def run(
                 user code file (default: None).
         """
         self._current_job_name = self._generate_current_job_name(job_name=job_name)
-        self.command = [_SparkProcessorBase._default_command]
 
         if not submit_app:
             raise ValueError("submit_app is required")

@@ -868,6 +866,7 @@ def _extend_processing_args(self, inputs, outputs, **kwargs):
             outputs: Processing outputs.
             kwargs: Additional keyword arguments passed to `super()`.
         """
+        self.command = [_SparkProcessorBase._default_command]
         extended_inputs = self._handle_script_dependencies(
             inputs, kwargs.get("submit_py_files"), FileType.PYTHON
         )

@@ -976,7 +975,7 @@ def get_run_args(
         spark_event_logs_s3_uri=None,
     ):
         """Returns a RunArgs object. This object contains the normalized inputs, outputs
-        and arguments needed when creating using a ``SparkJarProcessor`` in a :class:`~sagemaker.workflow.steps.ProcessingStep`.
+        and arguments needed when using a ``SparkJarProcessor`` in a :class:`~sagemaker.workflow.steps.ProcessingStep`.
 
         Args:
             submit_app (str): Path (local or S3) to Python file to submit to Spark

@@ -1004,7 +1003,6 @@ def get_run_args(
                 be published to.
         """
         self._current_job_name = self._generate_current_job_name(job_name=job_name)
-        self.command = [_SparkProcessorBase._default_command]
 
         if not submit_app:
             raise ValueError("submit_app is required")

@@ -1079,7 +1077,6 @@ def run(
                 user code file (default: None).
         """
         self._current_job_name = self._generate_current_job_name(job_name=job_name)
-        self.command = [_SparkProcessorBase._default_command]
 
         if not submit_app:
             raise ValueError("submit_app is required")

@@ -1107,6 +1104,7 @@ def run(
         )
 
     def _extend_processing_args(self, inputs, outputs, **kwargs):
+        self.command = [_SparkProcessorBase._default_command]
         if kwargs.get("submit_class"):
             self.command.extend(["--class", kwargs.get("submit_class")])
         else:
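
One detail worth noting in this file: the ``self.command = [_SparkProcessorBase._default_command]`` reset is deleted from both ``run()`` and ``get_run_args()`` and re-added inside ``_extend_processing_args()``. Since both public entry points apparently funnel through ``_extend_processing_args()``, the Spark command list is initialized in exactly one place no matter which method the caller uses. A toy sketch of that pattern (the class, default command and call signatures here are stand-ins, not the real SDK classes):

class FakeSparkProcessor:
    """Toy stand-in showing why the command reset lives in one shared hook."""

    _default_command = "spark-submit"  # placeholder value

    def run(self, submit_app, inputs=None, outputs=None, **kwargs):
        # Public entry point #1: no command handling here any more.
        return self._extend_processing_args(inputs, outputs, submit_app=submit_app, **kwargs)

    def get_run_args(self, submit_app, inputs=None, outputs=None, **kwargs):
        # Public entry point #2: defers to the same shared hook.
        return self._extend_processing_args(inputs, outputs, submit_app=submit_app, **kwargs)

    def _extend_processing_args(self, inputs, outputs, **kwargs):
        # Single place where the command is (re)initialized per invocation.
        self.command = [self._default_command]
        if kwargs.get("submit_class"):
            self.command.extend(["--class", kwargs["submit_class"]])
        return self.command


p = FakeSparkProcessor()
print(p.run("app.py"))                                  # ['spark-submit']
print(p.get_run_args("app.jar", submit_class="Main"))   # ['spark-submit', '--class', 'Main']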

tests/integ/test_workflow.py
Lines changed: 25 additions & 19 deletions

@@ -104,6 +104,7 @@ def athena_dataset_definition(sagemaker_session):
         ),
     )
 
+
 @pytest.fixture
 def configuration() -> list:
     configuration = [

@@ -183,6 +184,7 @@ def configuration() -> list:
     ]
     return configuration
 
+
 @pytest.fixture(scope="module")
 def build_jar():
     spark_path = os.path.join(DATA_DIR, "spark")

@@ -224,6 +226,7 @@ def build_jar():
     subprocess.run(["rm", os.path.join(jar_file_path, "hello-spark-java.jar")])
     subprocess.run(["rm", os.path.join(jar_file_path, java_file_path, "HelloJavaSparkApp.class")])
 
+
 def test_three_step_definition(
     sagemaker_session,
     region_name,

@@ -473,6 +476,7 @@ def test_one_step_sklearn_processing_pipeline(
     except Exception:
         pass
 
+
 def test_one_step_pyspark_processing_pipeline(
     sagemaker_session,
     role,

@@ -496,12 +500,18 @@ def test_one_step_pyspark_processing_pipeline(
     )
 
     spark_run_args = pyspark_processor.get_run_args(
-        submit_app=script_path,
-        arguments=["--s3_input_bucket", sagemaker_session.default_bucket(),
-                   "--s3_input_key_prefix", "spark-input",
-                   "--s3_output_bucket", sagemaker_session.default_bucket(),
-                   "--s3_output_key_prefix", "spark-output"],
-    )
+        submit_app=script_path,
+        arguments=[
+            "--s3_input_bucket",
+            sagemaker_session.default_bucket(),
+            "--s3_input_key_prefix",
+            "spark-input",
+            "--s3_output_bucket",
+            sagemaker_session.default_bucket(),
+            "--s3_output_key_prefix",
+            "spark-output",
+        ],
+    )
 
     step_pyspark = ProcessingStep(
         name="pyspark-process",

@@ -520,12 +530,12 @@ def test_one_step_pyspark_processing_pipeline(
     )
 
     try:
-        # NOTE: We should exercise the case when role used in the pipeline execution is
-        # different than that required of the steps in the pipeline itself. The role in
-        # the pipeline definition needs to create training and processing jobs and other
-        # sagemaker entities. However, the jobs created in the steps themselves execute
-        # under a potentially different role, often requiring access to S3 and other
-        # artifacts not required to during creation of the jobs in the pipeline steps.
+        # NOTE: We should exercise the case when role used in the pipeline execution is
+        # different than that required of the steps in the pipeline itself. The role in
+        # the pipeline definition needs to create training and processing jobs and other
+        # sagemaker entities. However, the jobs created in the steps themselves execute
+        # under a potentially different role, often requiring access to S3 and other
+        # artifacts not required to during creation of the jobs in the pipeline steps.
         response = pipeline.create(role)
         create_arn = response["PipelineArn"]
         assert re.match(

@@ -568,14 +578,9 @@ def test_one_step_pyspark_processing_pipeline(
     except Exception:
         pass
 
+
 def test_one_step_sparkjar_processing_pipeline(
-    sagemaker_session,
-    role,
-    cpu_instance_type,
-    pipeline_name,
-    region_name,
-    configuration,
-    build_jar
+    sagemaker_session, role, cpu_instance_type, pipeline_name, region_name, configuration, build_jar
 ):
     instance_count = ParameterInteger(name="InstanceCount", default_value=2)
     cache_config = CacheConfig(enable_caching=True, expire_after="T30m")

@@ -670,6 +675,7 @@ def test_one_step_sparkjar_processing_pipeline(
     except Exception:
         pass
 
+
 def test_conditional_pytorch_training_model_registration(
     sagemaker_session,
     role,

tests/unit/sagemaker/spark/test_processing.py
Lines changed: 1 addition & 15 deletions

@@ -776,7 +776,6 @@ def test_py_spark_processor_run(
                 "inputs": [],
                 "opt": None,
                 "arguments": ["arg1"],
-                "kms_key": "test_kms_key",
             },
             ValueError,
         ),

@@ -787,7 +786,6 @@ def test_py_spark_processor_run(
                 "inputs": [processing_input],
                 "opt": None,
                 "arguments": ["arg1"],
-                "kms_key": "test_kms_key",
             },
             [processing_input],
         ),

@@ -798,7 +796,6 @@ def test_py_spark_processor_run(
                 "inputs": [processing_input],
                 "opt": None,
                 "arguments": ["arg1"],
-                "kms_key": "test_kms_key",
             },
             [processing_input, processing_input, processing_input, processing_input],
         ),

@@ -809,7 +806,6 @@ def test_py_spark_processor_run(
                 "inputs": None,
                 "opt": None,
                 "arguments": ["arg1"],
-                "kms_key": "test_kms_key",
             },
             [processing_input, processing_input, processing_input],
         ),

@@ -820,7 +816,6 @@ def test_py_spark_processor_run(
                 "inputs": None,
                 "opt": "opt",
                 "arguments": ["arg1"],
-                "kms_key": "test_kms_key",
             },
             [processing_input, processing_input, processing_input],
         ),

@@ -878,7 +873,6 @@ def test_py_spark_processor_get_run_args(
                 "inputs": [],
                 "opt": None,
                 "arguments": ["arg1"],
-                "kms_key": "test_kms_key",
             },
             ValueError,
         ),

@@ -889,7 +883,6 @@ def test_py_spark_processor_get_run_args(
                 "inputs": [processing_input],
                 "opt": None,
                 "arguments": ["arg1"],
-                "kms_key": "test_kms_key",
             },
             [processing_input],
         ),

@@ -900,7 +893,6 @@ def test_py_spark_processor_get_run_args(
                 "inputs": [processing_input],
                 "opt": None,
                 "arguments": ["arg1"],
-                "kms_key": "test_kms_key",
             },
             [processing_input, processing_input, processing_input, processing_input],
         ),

@@ -911,7 +903,6 @@ def test_py_spark_processor_get_run_args(
                 "inputs": None,
                 "opt": None,
                 "arguments": ["arg1"],
-                "kms_key": "test_kms_key",
             },
             [processing_input, processing_input, processing_input],
         ),

@@ -922,7 +913,6 @@ def test_py_spark_processor_get_run_args(
                 "inputs": None,
                 "opt": "opt",
                 "arguments": ["arg1"],
-                "kms_key": "test_kms_key",
             },
             [processing_input, processing_input, processing_input],
         ),

@@ -951,7 +941,6 @@ def test_py_spark_processor_get_run_args(
             submit_files=config["files"],
             inputs=config["inputs"],
             arguments=config["arguments"],
-            kms_key=config["kms_key"],
         )
     else:
         py_spark_processor.get_run_args(

@@ -961,16 +950,13 @@ def test_py_spark_processor_get_run_args(
             submit_files=config["files"],
             inputs=config["inputs"],
             arguments=config["arguments"],
-            kms_key=config["kms_key"],
         )
 
     mock_super_get_run_args.assert_called_with(
-        submit_app=config["submit_app"],
+        code=config["submit_app"],
         inputs=expected,
         outputs=None,
         arguments=config["arguments"],
-        job_name="jobName",
-        kms_key=config["kms_key"],
     )
tests/unit/sagemaker/workflow/test_steps.py
Lines changed: 0 additions & 1 deletion

@@ -183,7 +183,6 @@ def test_processing_step(sagemaker_session):
         outputs=[],
         cache_config=cache_config,
     )
-    print(f"StepToRequest is {step.to_request()}")
     assert step.to_request() == {
         "Name": "MyProcessingStep",
         "Type": "Processing",
