from sagemaker.tuner import HyperparameterTuner
from sagemaker.workflow.pipeline_context import PipelineSession

- from sagemaker.processing import Processor, ScriptProcessor, FrameworkProcessor
+ from sagemaker.processing import Processor, ScriptProcessor, FrameworkProcessor, ProcessingOutput, ProcessingInput
from sagemaker.sklearn.processing import SKLearnProcessor
from sagemaker.pytorch.processing import PyTorchProcessor
from sagemaker.tensorflow.processing import TensorFlowProcessor
@@ ... @@
from sagemaker.wrangler.processing import DataWranglerProcessor
from sagemaker.spark.processing import SparkJarProcessor, PySparkProcessor

- from sagemaker.processing import ProcessingInput

from sagemaker.workflow.steps import CacheConfig, ProcessingStep
from sagemaker.workflow.pipeline import Pipeline
from sagemaker.workflow.properties import PropertyFile
+ from sagemaker.workflow.parameters import ParameterString
+ from sagemaker.workflow.functions import Join

from sagemaker.network import NetworkConfig
from sagemaker.pytorch.estimator import PyTorch
@@ ... @@
DUMMY_S3_SCRIPT_PATH = "s3://dummy-s3/dummy_script.py"
INSTANCE_TYPE = "ml.m4.xlarge"

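+ # Module-level (processor, run_kwargs) pairs shared by the framework-processor
+ # tests below; each processor is exercised with the same dummy entry-point script.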
+ FRAMEWORK_PROCESSOR = [
+     (
+         FrameworkProcessor(
+             framework_version="1.8",
+             instance_type=INSTANCE_TYPE,
+             instance_count=1,
+             role=ROLE,
+             estimator_cls=PyTorch,
+         ),
+         {"code": DUMMY_S3_SCRIPT_PATH},
+     ),
+     (
+         SKLearnProcessor(
+             framework_version="0.23-1",
+             instance_type=INSTANCE_TYPE,
+             instance_count=1,
+             role=ROLE,
+         ),
+         {"code": DUMMY_S3_SCRIPT_PATH},
+     ),
+     (
+         PyTorchProcessor(
+             role=ROLE,
+             instance_type=INSTANCE_TYPE,
+             instance_count=1,
+             framework_version="1.8.0",
+             py_version="py3",
+         ),
+         {"code": DUMMY_S3_SCRIPT_PATH},
+     ),
+     (
+         TensorFlowProcessor(
+             role=ROLE,
+             instance_type=INSTANCE_TYPE,
+             instance_count=1,
+             framework_version="2.0",
+         ),
+         {"code": DUMMY_S3_SCRIPT_PATH},
+     ),
+     (
+         HuggingFaceProcessor(
+             transformers_version="4.6",
+             pytorch_version="1.7",
+             role=ROLE,
+             instance_count=1,
+             instance_type="ml.p3.2xlarge",
+         ),
+         {"code": DUMMY_S3_SCRIPT_PATH},
+     ),
+     (
+         XGBoostProcessor(
+             framework_version="1.3-1",
+             py_version="py3",
+             role=ROLE,
+             instance_count=1,
+             instance_type=INSTANCE_TYPE,
+             base_job_name="test-xgboost",
+         ),
+         {"code": DUMMY_S3_SCRIPT_PATH},
+     ),
+     (
+         MXNetProcessor(
+             framework_version="1.4.1",
+             py_version="py3",
+             role=ROLE,
+             instance_count=1,
+             instance_type=INSTANCE_TYPE,
+             base_job_name="test-mxnet",
+         ),
+         {"code": DUMMY_S3_SCRIPT_PATH},
+     ),
+     (
+         DataWranglerProcessor(
+             role=ROLE,
+             data_wrangler_flow_source="s3://my-bucket/dw.flow",
+             instance_count=1,
+             instance_type=INSTANCE_TYPE,
+         ),
+         {},
+     ),
+     (
+         SparkJarProcessor(
+             role=ROLE,
+             framework_version="2.4",
+             instance_count=1,
+             instance_type=INSTANCE_TYPE,
+         ),
+         {
+             "submit_app": "s3://my-jar",
+             "submit_class": "com.amazonaws.sagemaker.spark.test.HelloJavaSparkApp",
+             "arguments": ["--input", "input-data-uri", "--output", "output-data-uri"],
+         },
+     ),
+     (
+         PySparkProcessor(
+             role=ROLE,
+             framework_version="2.4",
+             instance_count=1,
+             instance_type=INSTANCE_TYPE,
+         ),
+         {
+             "submit_app": "s3://my-jar",
+             "arguments": ["--input", "input-data-uri", "--output", "output-data-uri"],
+         },
+     ),
+ ]
+
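+ # ProcessingInput sources to parametrize over: a plain S3 URI, a ParameterString
+ # with and without a default, and a Join expression, all resolved at runtime.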
+ PROCESSING_INPUT = [
+     ProcessingInput(source="s3://my-bucket/processing_manifest", destination="processing_manifest"),
+     ProcessingInput(
+         source=ParameterString(name="my-processing-input"),
+         destination="processing-input",
+     ),
+     ProcessingInput(
+         source=ParameterString(name="my-processing-input", default_value="s3://my-bucket/my-processing"),
+         destination="processing-input",
+     ),
+     ProcessingInput(
+         source=Join(on="/", values=["s3://my-bucket", "my-input"]),
+         destination="processing-input",
+     ),
+ ]
+
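+ # ProcessingOutput destinations, mirroring the input variants above.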
+ PROCESSING_OUTPUT = [
+     ProcessingOutput(source="/opt/ml/output", destination="s3://my-bucket/my-output"),
+     ProcessingOutput(source="/opt/ml/output", destination=ParameterString(name="my-output")),
+     ProcessingOutput(
+         source="/opt/ml/output",
+         destination=ParameterString(name="my-output", default_value="s3://my-bucket/my-output"),
+     ),
+     ProcessingOutput(
+         source="/opt/ml/output",
+         destination=Join(on="/", values=["s3://my-bucket", "my-output"]),
+     ),
+ ]
+

@pytest.fixture
def client():
@@ -253,117 +388,11 @@ def test_processing_step_with_script_processor(pipeline_session, processing_inpu
}


- @pytest.mark.parametrize(
-     "framework_processor",
-     [
-         (
-             FrameworkProcessor(
-                 framework_version="1.8",
-                 instance_type=INSTANCE_TYPE,
-                 instance_count=1,
-                 role=ROLE,
-                 estimator_cls=PyTorch,
-             ),
-             {"code": DUMMY_S3_SCRIPT_PATH},
-         ),
-         (
-             SKLearnProcessor(
-                 framework_version="0.23-1",
-                 instance_type=INSTANCE_TYPE,
-                 instance_count=1,
-                 role=ROLE,
-             ),
-             {"code": DUMMY_S3_SCRIPT_PATH},
-         ),
-         (
-             PyTorchProcessor(
-                 role=ROLE,
-                 instance_type=INSTANCE_TYPE,
-                 instance_count=1,
-                 framework_version="1.8.0",
-                 py_version="py3",
-             ),
-             {"code": DUMMY_S3_SCRIPT_PATH},
-         ),
-         (
-             TensorFlowProcessor(
-                 role=ROLE,
-                 instance_type=INSTANCE_TYPE,
-                 instance_count=1,
-                 framework_version="2.0",
-             ),
-             {"code": DUMMY_S3_SCRIPT_PATH},
-         ),
-         (
-             HuggingFaceProcessor(
-                 transformers_version="4.6",
-                 pytorch_version="1.7",
-                 role=ROLE,
-                 instance_count=1,
-                 instance_type="ml.p3.2xlarge",
-             ),
-             {"code": DUMMY_S3_SCRIPT_PATH},
-         ),
-         (
-             XGBoostProcessor(
-                 framework_version="1.3-1",
-                 py_version="py3",
-                 role=ROLE,
-                 instance_count=1,
-                 instance_type=INSTANCE_TYPE,
-                 base_job_name="test-xgboost",
-             ),
-             {"code": DUMMY_S3_SCRIPT_PATH},
-         ),
-         (
-             MXNetProcessor(
-                 framework_version="1.4.1",
-                 py_version="py3",
-                 role=ROLE,
-                 instance_count=1,
-                 instance_type=INSTANCE_TYPE,
-                 base_job_name="test-mxnet",
-             ),
-             {"code": DUMMY_S3_SCRIPT_PATH},
-         ),
-         (
-             DataWranglerProcessor(
-                 role=ROLE,
-                 data_wrangler_flow_source=f"s3://{BUCKET}/dw.flow",
-                 instance_count=1,
-                 instance_type=INSTANCE_TYPE,
-             ),
-             {},
-         ),
-         (
-             SparkJarProcessor(
-                 role=ROLE,
-                 framework_version="2.4",
-                 instance_count=1,
-                 instance_type=INSTANCE_TYPE,
-             ),
-             {
-                 "submit_app": "s3://my-jar",
-                 "submit_class": "com.amazonaws.sagemaker.spark.test.HelloJavaSparkApp",
-                 "arguments": ["--input", "input-data-uri", "--output", "output-data-uri"],
-             },
-         ),
-         (
-             PySparkProcessor(
-                 role=ROLE,
-                 framework_version="2.4",
-                 instance_count=1,
-                 instance_type=INSTANCE_TYPE,
-             ),
-             {
-                 "submit_app": "s3://my-jar",
-                 "arguments": ["--input", "input-data-uri", "--output", "output-data-uri"],
-             },
-         ),
-     ],
- )
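+ # Stacked parametrize decorators produce the cross-product: every processor is
+ # run against every input variant and every output variant.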
+ @pytest.mark.parametrize("framework_processor", FRAMEWORK_PROCESSOR)
+ @pytest.mark.parametrize("processing_input", PROCESSING_INPUT)
+ @pytest.mark.parametrize("processing_output", PROCESSING_OUTPUT)
def test_processing_step_with_framework_processor(
-     framework_processor, pipeline_session, processing_input, network_config
+     framework_processor, pipeline_session, processing_input, processing_output, network_config
):

    processor, run_inputs = framework_processor
@@ -373,7 +402,8 @@ def test_processing_step_with_framework_processor(
    processor.volume_kms_key = "volume-kms-key"
    processor.network_config = network_config

-     run_inputs["inputs"] = processing_input
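+     # run() takes lists of inputs and outputs; wrap the single parametrized case.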
+     run_inputs["inputs"] = [processing_input]
+     run_inputs["outputs"] = [processing_output]

    step_args = processor.run(**run_inputs)

@@ -387,10 +417,22 @@ def test_processing_step_with_framework_processor(
        sagemaker_session=pipeline_session,
    )

-     assert json.loads(pipeline.definition())["Steps"][0] == {
+     step_args = step_args.args
+     step_def = json.loads(pipeline.definition())["Steps"][0]
+
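+     # ParameterString and Join values serialize to pipeline expressions in the
+     # definition JSON, so check the S3Uri fields against the parametrized values
+     # first, then drop them before comparing the remaining structure verbatim.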
+     assert step_args["ProcessingInputs"][0]["S3Input"]["S3Uri"] == processing_input.source
+     assert step_args["ProcessingOutputConfig"]["Outputs"][0]["S3Output"]["S3Uri"] == processing_output.destination
+
+     del step_args["ProcessingInputs"][0]["S3Input"]["S3Uri"]
+     del step_def["Arguments"]["ProcessingInputs"][0]["S3Input"]["S3Uri"]
+
+     del step_args["ProcessingOutputConfig"]["Outputs"][0]["S3Output"]["S3Uri"]
+     del step_def["Arguments"]["ProcessingOutputConfig"]["Outputs"][0]["S3Output"]["S3Uri"]
+
+     assert step_def == {
        "Name": "MyProcessingStep",
        "Type": "Processing",
-         "Arguments": step_args.args,
+         "Arguments": step_args,
    }