Commit fbd6d28

add input parameterization tests for workflow job steps

1 parent 4acbdb0 commit fbd6d28

3 files changed: +300 -231 lines changed

tests/unit/sagemaker/workflow/test_processing_step.py

Lines changed: 157 additions & 115 deletions

@@ -24,7 +24,7 @@
 from sagemaker.tuner import HyperparameterTuner
 from sagemaker.workflow.pipeline_context import PipelineSession

-from sagemaker.processing import Processor, ScriptProcessor, FrameworkProcessor
+from sagemaker.processing import Processor, ScriptProcessor, FrameworkProcessor, ProcessingOutput, ProcessingInput
 from sagemaker.sklearn.processing import SKLearnProcessor
 from sagemaker.pytorch.processing import PyTorchProcessor
 from sagemaker.tensorflow.processing import TensorFlowProcessor
@@ -34,11 +34,12 @@
 from sagemaker.wrangler.processing import DataWranglerProcessor
 from sagemaker.spark.processing import SparkJarProcessor, PySparkProcessor

-from sagemaker.processing import ProcessingInput

 from sagemaker.workflow.steps import CacheConfig, ProcessingStep
 from sagemaker.workflow.pipeline import Pipeline
 from sagemaker.workflow.properties import PropertyFile
+from sagemaker.workflow.parameters import ParameterString
+from sagemaker.workflow.functions import Join

 from sagemaker.network import NetworkConfig
 from sagemaker.pytorch.estimator import PyTorch
@@ -62,6 +63,140 @@
 DUMMY_S3_SCRIPT_PATH = "s3://dummy-s3/dummy_script.py"
 INSTANCE_TYPE = "ml.m4.xlarge"

+FRAMEWORK_PROCESSOR = [
+    (
+        FrameworkProcessor(
+            framework_version="1.8",
+            instance_type=INSTANCE_TYPE,
+            instance_count=1,
+            role=ROLE,
+            estimator_cls=PyTorch,
+        ),
+        {"code": DUMMY_S3_SCRIPT_PATH},
+    ),
+    (
+        SKLearnProcessor(
+            framework_version="0.23-1",
+            instance_type=INSTANCE_TYPE,
+            instance_count=1,
+            role=ROLE,
+        ),
+        {"code": DUMMY_S3_SCRIPT_PATH},
+    ),
+    (
+        PyTorchProcessor(
+            role=ROLE,
+            instance_type=INSTANCE_TYPE,
+            instance_count=1,
+            framework_version="1.8.0",
+            py_version="py3",
+        ),
+        {"code": DUMMY_S3_SCRIPT_PATH},
+    ),
+    (
+        TensorFlowProcessor(
+            role=ROLE,
+            instance_type=INSTANCE_TYPE,
+            instance_count=1,
+            framework_version="2.0",
+        ),
+        {"code": DUMMY_S3_SCRIPT_PATH},
+    ),
+    (
+        HuggingFaceProcessor(
+            transformers_version="4.6",
+            pytorch_version="1.7",
+            role=ROLE,
+            instance_count=1,
+            instance_type="ml.p3.2xlarge",
+        ),
+        {"code": DUMMY_S3_SCRIPT_PATH},
+    ),
+    (
+        XGBoostProcessor(
+            framework_version="1.3-1",
+            py_version="py3",
+            role=ROLE,
+            instance_count=1,
+            instance_type=INSTANCE_TYPE,
+            base_job_name="test-xgboost",
+        ),
+        {"code": DUMMY_S3_SCRIPT_PATH},
+    ),
+    (
+        MXNetProcessor(
+            framework_version="1.4.1",
+            py_version="py3",
+            role=ROLE,
+            instance_count=1,
+            instance_type=INSTANCE_TYPE,
+            base_job_name="test-mxnet",
+        ),
+        {"code": DUMMY_S3_SCRIPT_PATH},
+    ),
+    (
+        DataWranglerProcessor(
+            role=ROLE,
+            data_wrangler_flow_source="s3://my-bucket/dw.flow",
+            instance_count=1,
+            instance_type=INSTANCE_TYPE,
+        ),
+        {},
+    ),
+    (
+        SparkJarProcessor(
+            role=ROLE,
+            framework_version="2.4",
+            instance_count=1,
+            instance_type=INSTANCE_TYPE,
+        ),
+        {
+            "submit_app": "s3://my-jar",
+            "submit_class": "com.amazonaws.sagemaker.spark.test.HelloJavaSparkApp",
+            "arguments": ["--input", "input-data-uri", "--output", "output-data-uri"],
+        },
+    ),
+    (
+        PySparkProcessor(
+            role=ROLE,
+            framework_version="2.4",
+            instance_count=1,
+            instance_type=INSTANCE_TYPE,
+        ),
+        {
+            "submit_app": "s3://my-jar",
+            "arguments": ["--input", "input-data-uri", "--output", "output-data-uri"],
+        },
+    ),
+]
+
+PROCESSING_INPUT = [
+    ProcessingInput(source="s3://my-bucket/processing_manifest", destination="processing_manifest"),
+    ProcessingInput(
+        source=ParameterString(name="my-processing-input"),
+        destination="processing-input",
+    ),
+    ProcessingInput(
+        source=ParameterString(name="my-processing-input", default_value="s3://my-bucket/my-processing"),
+        destination="processing-input",
+    ),
+    ProcessingInput(
+        source=Join(on="/", values=["s3://my-bucket", "my-input"]),
+        destination="processing-input",
+    ),
+]
+
+PROCESSING_OUTPUT = [
+    ProcessingOutput(source="/opt/ml/output", destination="s3://my-bucket/my-output"),
+    ProcessingOutput(source="/opt/ml/output", destination=ParameterString(name="my-output")),
+    ProcessingOutput(
+        source="/opt/ml/output",
+        destination=ParameterString(name="my-output", default_value="s3://my-bucket/my-output"),
+    ),
+    ProcessingOutput(
+        source="/opt/ml/output",
+        destination=Join(on="/", values=["s3://my-bucket", "my-output"]),
+    ),
+]
+

 @pytest.fixture
 def client():
@@ -253,117 +388,11 @@ def test_processing_step_with_script_processor(pipeline_session, processing_input
     }


-@pytest.mark.parametrize(
-    "framework_processor",
-    [
-        (
-            FrameworkProcessor(
-                framework_version="1.8",
-                instance_type=INSTANCE_TYPE,
-                instance_count=1,
-                role=ROLE,
-                estimator_cls=PyTorch,
-            ),
-            {"code": DUMMY_S3_SCRIPT_PATH},
-        ),
-        (
-            SKLearnProcessor(
-                framework_version="0.23-1",
-                instance_type=INSTANCE_TYPE,
-                instance_count=1,
-                role=ROLE,
-            ),
-            {"code": DUMMY_S3_SCRIPT_PATH},
-        ),
-        (
-            PyTorchProcessor(
-                role=ROLE,
-                instance_type=INSTANCE_TYPE,
-                instance_count=1,
-                framework_version="1.8.0",
-                py_version="py3",
-            ),
-            {"code": DUMMY_S3_SCRIPT_PATH},
-        ),
-        (
-            TensorFlowProcessor(
-                role=ROLE,
-                instance_type=INSTANCE_TYPE,
-                instance_count=1,
-                framework_version="2.0",
-            ),
-            {"code": DUMMY_S3_SCRIPT_PATH},
-        ),
-        (
-            HuggingFaceProcessor(
-                transformers_version="4.6",
-                pytorch_version="1.7",
-                role=ROLE,
-                instance_count=1,
-                instance_type="ml.p3.2xlarge",
-            ),
-            {"code": DUMMY_S3_SCRIPT_PATH},
-        ),
-        (
-            XGBoostProcessor(
-                framework_version="1.3-1",
-                py_version="py3",
-                role=ROLE,
-                instance_count=1,
-                instance_type=INSTANCE_TYPE,
-                base_job_name="test-xgboost",
-            ),
-            {"code": DUMMY_S3_SCRIPT_PATH},
-        ),
-        (
-            MXNetProcessor(
-                framework_version="1.4.1",
-                py_version="py3",
-                role=ROLE,
-                instance_count=1,
-                instance_type=INSTANCE_TYPE,
-                base_job_name="test-mxnet",
-            ),
-            {"code": DUMMY_S3_SCRIPT_PATH},
-        ),
-        (
-            DataWranglerProcessor(
-                role=ROLE,
-                data_wrangler_flow_source=f"s3://{BUCKET}/dw.flow",
-                instance_count=1,
-                instance_type=INSTANCE_TYPE,
-            ),
-            {},
-        ),
-        (
-            SparkJarProcessor(
-                role=ROLE,
-                framework_version="2.4",
-                instance_count=1,
-                instance_type=INSTANCE_TYPE,
-            ),
-            {
-                "submit_app": "s3://my-jar",
-                "submit_class": "com.amazonaws.sagemaker.spark.test.HelloJavaSparkApp",
-                "arguments": ["--input", "input-data-uri", "--output", "output-data-uri"],
-            },
-        ),
-        (
-            PySparkProcessor(
-                role=ROLE,
-                framework_version="2.4",
-                instance_count=1,
-                instance_type=INSTANCE_TYPE,
-            ),
-            {
-                "submit_app": "s3://my-jar",
-                "arguments": ["--input", "input-data-uri", "--output", "output-data-uri"],
-            },
-        ),
-    ],
-)
+@pytest.mark.parametrize("framework_processor", FRAMEWORK_PROCESSOR)
+@pytest.mark.parametrize("processing_input", PROCESSING_INPUT)
+@pytest.mark.parametrize("processing_output", PROCESSING_OUTPUT)
 def test_processing_step_with_framework_processor(
-    framework_processor, pipeline_session, processing_input, network_config
+    framework_processor, pipeline_session, processing_input, processing_output, network_config
 ):

     processor, run_inputs = framework_processor
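
Note (reviewer sketch, not part of the commit): stacked @pytest.mark.parametrize decorators take the cross product of their argument lists, so this single test now expands to 10 processors × 4 inputs × 4 outputs = 160 generated cases. A minimal illustration of the mechanism, with hypothetical names:

    import pytest

    @pytest.mark.parametrize("a", [1, 2])      # 2 values
    @pytest.mark.parametrize("b", ["x", "y"])  # × 2 values
    def test_cross_product(a, b):
        # pytest collects 4 cases: (1, "x"), (1, "y"), (2, "x"), (2, "y")
        assert a in (1, 2) and b in ("x", "y")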
@@ -373,7 +402,8 @@ def test_processing_step_with_framework_processor(
     processor.volume_kms_key = "volume-kms-key"
     processor.network_config = network_config

-    run_inputs["inputs"] = processing_input
+    run_inputs["inputs"] = [processing_input]
+    run_inputs["outputs"] = [processing_output]

     step_args = processor.run(**run_inputs)
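
Note (reviewer sketch, not part of the commit): the run() methods of these processors accept lists of ProcessingInput/ProcessingOutput objects, which is why each single parametrized value is wrapped in [...] above. For example, assuming a framework processor and pipeline session as in this test, with hypothetical S3 URIs:

    step_args = processor.run(
        code=DUMMY_S3_SCRIPT_PATH,
        inputs=[ProcessingInput(source="s3://my-bucket/in", destination="/opt/ml/processing/input")],
        outputs=[ProcessingOutput(source="/opt/ml/processing/output", destination="s3://my-bucket/out")],
    )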

@@ -387,10 +417,22 @@ def test_processing_step_with_framework_processor(
         sagemaker_session=pipeline_session,
     )

-    assert json.loads(pipeline.definition())["Steps"][0] == {
+    step_args = step_args.args
+    step_def = json.loads(pipeline.definition())["Steps"][0]
+
+    assert step_args["ProcessingInputs"][0]["S3Input"]["S3Uri"] == processing_input.source
+    assert step_args["ProcessingOutputConfig"]["Outputs"][0]["S3Output"]["S3Uri"] == processing_output.destination
+
+    del step_args["ProcessingInputs"][0]["S3Input"]["S3Uri"]
+    del step_def["Arguments"]["ProcessingInputs"][0]["S3Input"]["S3Uri"]
+
+    del step_args["ProcessingOutputConfig"]["Outputs"][0]["S3Output"]["S3Uri"]
+    del step_def["Arguments"]["ProcessingOutputConfig"]["Outputs"][0]["S3Output"]["S3Uri"]
+
+    assert step_def == {
         "Name": "MyProcessingStep",
         "Type": "Processing",
-        "Arguments": step_args.args,
+        "Arguments": step_args,
     }
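
Note (reviewer sketch, not part of the commit): the S3Uri fields are asserted separately and then deleted before the final dict comparison because pipeline variables do not serialize to plain strings. In the raw step args they are still ParameterString/Join objects, while pipeline.definition() renders them as expressions, e.g.:

    from sagemaker.workflow.functions import Join
    from sagemaker.workflow.parameters import ParameterString

    param = ParameterString(name="my-processing-input")
    join = Join(on="/", values=["s3://my-bucket", "my-input"])

    # .expr is the form a pipeline variable takes inside the definition JSON
    print(param.expr)  # {'Get': 'Parameters.my-processing-input'}
    print(join.expr)   # {'Std:Join': {'On': '/', 'Values': ['s3://my-bucket', 'my-input']}}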
