    rule_configs,
)
from datetime import datetime
+from sagemaker import image_uris
from sagemaker.inputs import CreateModelInput, TrainingInput
from sagemaker.model import Model
from sagemaker.processing import ProcessingInput, ProcessingOutput
@@ -39,6 +40,7 @@
from sagemaker.spark.processing import PySparkProcessor, SparkJarProcessor
from sagemaker.workflow.conditions import ConditionGreaterThanOrEqualTo
from sagemaker.workflow.condition_step import ConditionStep
+from sagemaker.workflow.processing import DataWranglerProcessor
from sagemaker.dataset_definition.inputs import DatasetDefinition, AthenaDatasetDefinition
from sagemaker.workflow.execution_variables import ExecutionVariables
from sagemaker.workflow.functions import Join
@@ -84,7 +86,7 @@ def script_dir():


@pytest.fixture
def pipeline_name():
-    return f"my-pipeline-{int(time.time() * 10**7)}"
+    return f"my-pipeline-{int(time.time() * 10**7)}"


@pytest.fixture
@@ -228,12 +230,12 @@ def build_jar():


def test_three_step_definition(
-    sagemaker_session,
-    region_name,
-    role,
-    script_dir,
-    pipeline_name,
-    athena_dataset_definition,
+    sagemaker_session,
+    region_name,
+    role,
+    script_dir,
+    pipeline_name,
+    athena_dataset_definition,
):
    framework_version = "0.20.0"
    instance_type = ParameterString(name="InstanceType", default_value="ml.m5.xlarge")
@@ -385,13 +387,13 @@ def test_three_step_definition(


def test_one_step_sklearn_processing_pipeline(
-    sagemaker_session,
-    role,
-    sklearn_latest_version,
-    cpu_instance_type,
-    pipeline_name,
-    region_name,
-    athena_dataset_definition,
+    sagemaker_session,
+    role,
+    sklearn_latest_version,
+    cpu_instance_type,
+    pipeline_name,
+    region_name,
+    athena_dataset_definition,
):
    instance_count = ParameterInteger(name="InstanceCount", default_value=2)
    script_path = os.path.join(DATA_DIR, "dummy_script.py")
@@ -478,11 +480,11 @@ def test_one_step_sklearn_processing_pipeline(


def test_one_step_pyspark_processing_pipeline(
-    sagemaker_session,
-    role,
-    cpu_instance_type,
-    pipeline_name,
-    region_name,
+    sagemaker_session,
+    role,
+    cpu_instance_type,
+    pipeline_name,
+    region_name,
):
    instance_count = ParameterInteger(name="InstanceCount", default_value=2)
    script_path = os.path.join(DATA_DIR, "dummy_script.py")
@@ -580,7 +582,7 @@ def test_one_step_pyspark_processing_pipeline(


def test_one_step_sparkjar_processing_pipeline(
-    sagemaker_session, role, cpu_instance_type, pipeline_name, region_name, configuration, build_jar
+    sagemaker_session, role, cpu_instance_type, pipeline_name, region_name, configuration, build_jar
):
    instance_count = ParameterInteger(name="InstanceCount", default_value=2)
    cache_config = CacheConfig(enable_caching=True, expire_after="T30m")
@@ -677,11 +679,11 @@ def test_one_step_sparkjar_processing_pipeline(


def test_conditional_pytorch_training_model_registration(
-    sagemaker_session,
-    role,
-    cpu_instance_type,
-    pipeline_name,
-    region_name,
+    sagemaker_session,
+    role,
+    cpu_instance_type,
+    pipeline_name,
+    region_name,
):
    base_dir = os.path.join(DATA_DIR, "pytorch_mnist")
    entry_point = os.path.join(base_dir, "mnist.py")
@@ -777,11 +779,11 @@ def test_conditional_pytorch_training_model_registration(


def test_training_job_with_debugger_and_profiler(
-    sagemaker_session,
-    pipeline_name,
-    role,
-    pytorch_training_latest_version,
-    pytorch_training_latest_py_version,
+    sagemaker_session,
+    pipeline_name,
+    role,
+    pytorch_training_latest_version,
+    pytorch_training_latest_py_version,
):
    instance_count = ParameterInteger(name="InstanceCount", default_value=1)
    instance_type = ParameterString(name="InstanceType", default_value="ml.m5.xlarge")
@@ -858,7 +860,7 @@ def test_training_job_with_debugger_and_profiler(
        assert config["RuleEvaluatorImage"] == rule.image_uri
        assert config["VolumeSizeInGB"] == 0
        assert (
-            config["RuleParameters"]["rule_to_invoke"] == rule.rule_parameters["rule_to_invoke"]
+            config["RuleParameters"]["rule_to_invoke"] == rule.rule_parameters["rule_to_invoke"]
        )
    assert job_description["DebugHookConfig"] == debugger_hook_config._to_request_dict()
@@ -869,3 +871,78 @@ def test_training_job_with_debugger_and_profiler(
        pipeline.delete()
    except Exception:
        pass
+
+
+def test_one_step_data_wrangler_processing_pipeline(
+    sagemaker_session, role, cpu_instance_type, pipeline_name, region_name
+):
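+    # Definition-only test: build a one-step pipeline that runs a Data
+    # Wrangler recipe as a processing job, then validate the rendered
+    # pipeline definition without executing it.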
+    instance_count = ParameterInteger(name="InstanceCount", default_value=1)
+
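+    # Test fixtures: a dummy Data Wrangler recipe (.flow) and the CSV it
+    # transforms, both stored under the test data directory.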
+    recipe_file_path = os.path.join(DATA_DIR, "workflow", "dummy_recipe.flow")
+    input_file_path = os.path.join(DATA_DIR, "workflow", "dummy_data.csv")
+
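+    # The recipe's output node is addressed by name; the container is told
+    # the desired content type through an --output-config job argument.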
+    output_name = "1bd0aaad-9c93-41b2-8d42-58e214f0843f.default"
+    output_content_type = "CSV"
+    output_config = {output_name: {"content_type": output_content_type}}
+    job_argument = [f"--output-config '{json.dumps(output_config)}'"]
+
+    inputs = [ProcessingInput(input_name="job_data", source=input_file_path, destination="/opt/ml/processing")]
+
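+    # Collect the transformed result from the container's output path into
+    # the session's default bucket.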
+    output_s3_uri = f"s3://{sagemaker_session.default_bucket()}/output"
+    outputs = [
+        ProcessingOutput(
+            output_name=output_name,
+            source="/opt/ml/processing/output",
+            destination=output_s3_uri,
+            s3_upload_mode="EndOfJob",
+        )
+    ]
+
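+    # The processor wraps the local recipe; instance count stays a pipeline
+    # parameter so it can be overridden per execution.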
+    data_wrangler_processor = DataWranglerProcessor(
+        role=role,
+        data_wrangler_recipe_source=recipe_file_path,
+        instance_count=instance_count,
+        instance_type=cpu_instance_type,
+        max_runtime_in_seconds=86400,
+    )
+
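+    # Wire the processor, inputs, outputs, and job arguments into a single
+    # pipeline step.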
+    data_wrangler_step = ProcessingStep(
+        name="data-wrangler-step",
+        processor=data_wrangler_processor,
+        inputs=inputs,
+        outputs=outputs,
+        job_arguments=job_argument,
+    )
+
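+    # Assemble the one-step pipeline against the test's SageMaker session.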
+    pipeline = Pipeline(
+        name=pipeline_name,
+        parameters=[instance_count],
+        steps=[data_wrangler_step],
+        sagemaker_session=sagemaker_session,
+    )
+
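+    # The step should resolve to the region-specific Data Wrangler image,
+    # as returned by image_uris.retrieve().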
+    definition = json.loads(pipeline.definition())
+    expected_image_uri = image_uris.retrieve("data-wrangler", region=sagemaker_session.boto_region_name)
+    assert len(definition["Steps"]) == 1
+    assert definition["Steps"][0]["Arguments"]["AppSpecification"]["ImageUri"] is not None
+    assert definition["Steps"][0]["Arguments"]["AppSpecification"]["ImageUri"] == expected_image_uri
+
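+    # Two inputs are expected: the recipe (staged by the processor under the
+    # "flow" input name) plus the user-supplied "job_data" CSV.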
+    assert definition["Steps"][0]["Arguments"]["ProcessingInputs"] is not None
+    processing_inputs = definition["Steps"][0]["Arguments"]["ProcessingInputs"]
+    assert len(processing_inputs) == 2
+    for processing_input in processing_inputs:
+        if processing_input["InputName"] == "flow":
+            assert processing_input["S3Input"]["S3Uri"].endswith(".flow")
+            assert processing_input["S3Input"]["LocalPath"] == "/opt/ml/processing/flow"
+        elif processing_input["InputName"] == "job_data":
+            assert processing_input["S3Input"]["S3Uri"].endswith(".csv")
+            assert processing_input["S3Input"]["LocalPath"] == "/opt/ml/processing"
+        else:
+            raise AssertionError("Unknown input name")
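+    # The single declared output should round-trip into the definition with
+    # its local path and S3 destination intact.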
+    assert definition["Steps"][0]["Arguments"]["ProcessingOutputConfig"] is not None
+    processing_outputs = definition["Steps"][0]["Arguments"]["ProcessingOutputConfig"]["Outputs"]
+    assert len(processing_outputs) == 1
+    assert processing_outputs[0]["OutputName"] == output_name
+    assert processing_outputs[0]["S3Output"] is not None
+    assert processing_outputs[0]["S3Output"]["LocalPath"] == "/opt/ml/processing/output"
+    assert processing_outputs[0]["S3Output"]["S3Uri"] == output_s3_uri