 from sagemaker.sklearn.processing import SKLearnProcessor
 from sagemaker.workflow.conditions import ConditionGreaterThanOrEqualTo
 from sagemaker.workflow.condition_step import ConditionStep
+from sagemaker.dataset_definition.inputs import DatasetDefinition, AthenaDatasetDefinition
 from sagemaker.workflow.parameters import (
     ParameterInteger,
     ParameterString,
@@ -97,8 +98,32 @@ def pipeline_name():
     return f"my-pipeline-{int(time.time() * 10**7)}"
 
 
+@pytest.fixture
+def athena_dataset_definition(sagemaker_session):
+    return DatasetDefinition(
+        local_path="/opt/ml/processing/input/add",
+        data_distribution_type="FullyReplicated",
+        input_mode="File",
+        athena_dataset_definition=AthenaDatasetDefinition(
+            catalog="AwsDataCatalog",
+            database="default",
+            work_group="workgroup",
+            query_string='SELECT * FROM "default"."s3_test_table_$STAGE_$REGIONUNDERSCORED";',
+            output_s3_uri=f"s3://{sagemaker_session.default_bucket()}/add",
+            output_format="JSON",
+            output_compression="GZIP",
+        ),
+    )
+
+
 def test_three_step_definition(
-    sagemaker_session, workflow_session, region_name, role, script_dir, pipeline_name
+    sagemaker_session,
+    workflow_session,
+    region_name,
+    role,
+    script_dir,
+    pipeline_name,
+    athena_dataset_definition,
 ):
     framework_version = "0.20.0"
     instance_type = ParameterString(name="InstanceType", default_value="ml.m5.xlarge")
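
A note on what the new fixture buys these tests: when a ProcessingInput carries this DatasetDefinition, SageMaker Processing runs the Athena query and stages its results inside the container under the fixture's local_path. A minimal sketch of how a processing script could consume that input, assuming gzipped JSON-lines output as configured above; the staged file names are an assumption, since Athena names its result objects itself:

import glob
import gzip
import json

# Results land under the DatasetDefinition's local_path ("/opt/ml/processing/input/add").
records = []
for path in glob.glob("/opt/ml/processing/input/add/*.gz"):  # assumed .gz suffix from GZIP compression
    with gzip.open(path, mode="rt") as f:
        records.extend(json.loads(line) for line in f if line.strip())
print(f"loaded {len(records)} rows from the Athena query")
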
@@ -117,7 +142,10 @@ def test_three_step_definition(
     step_process = ProcessingStep(
         name="my-process",
         processor=sklearn_processor,
-        inputs=[ProcessingInput(source=input_data, destination="/opt/ml/processing/input")],
+        inputs=[
+            ProcessingInput(source=input_data, destination="/opt/ml/processing/input"),
+            ProcessingInput(dataset_definition=athena_dataset_definition),
+        ],
         outputs=[
             ProcessingOutput(output_name="train_data", source="/opt/ml/processing/train"),
             ProcessingOutput(output_name="test_data", source="/opt/ml/processing/test"),
@@ -228,11 +256,15 @@ def test_one_step_sklearn_processing_pipeline(
     cpu_instance_type,
     pipeline_name,
     region,
+    athena_dataset_definition,
 ):
     instance_count = ParameterInteger(name="InstanceCount", default_value=2)
     script_path = os.path.join(DATA_DIR, "dummy_script.py")
     input_file_path = os.path.join(DATA_DIR, "dummy_input.txt")
-    inputs = [ProcessingInput(source=input_file_path, destination="/opt/ml/processing/inputs/")]
+    inputs = [
+        ProcessingInput(source=input_file_path, destination="/opt/ml/processing/inputs/"),
+        ProcessingInput(dataset_definition=athena_dataset_definition),
+    ]
 
     sklearn_processor = SKLearnProcessor(
         framework_version=sklearn_latest_version,
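
These hunks stop short of the pipeline assembly. For orientation, a minimal sketch of how the second test typically wires the modified inputs into a running pipeline, assuming a ProcessingStep named step_sklearn built from `inputs` and the `role`, `pipeline_name`, and `sagemaker_session` fixtures (these names are assumptions, not part of this diff):

from sagemaker.workflow.pipeline import Pipeline

# step_sklearn is assumed to be the ProcessingStep wrapping sklearn_processor and `inputs`.
pipeline = Pipeline(
    name=pipeline_name,
    parameters=[instance_count],
    steps=[step_sklearn],
    sagemaker_session=sagemaker_session,
)
pipeline.create(role_arn=role)  # register the pipeline definition
execution = pipeline.start()    # run it; InstanceCount keeps its default of 2
execution.wait()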