Commit e501593

add DatasetDefinition input in pipeline integ test
1 parent fd29b54 commit e501593

File tree

1 file changed: +35 −3 lines changed


tests/integ/test_workflow.py

Lines changed: 35 additions & 3 deletions
@@ -31,6 +31,7 @@
 from sagemaker.sklearn.processing import SKLearnProcessor
 from sagemaker.workflow.conditions import ConditionGreaterThanOrEqualTo
 from sagemaker.workflow.condition_step import ConditionStep
+from sagemaker.dataset_definition.inputs import DatasetDefinition, AthenaDatasetDefinition
 from sagemaker.workflow.parameters import (
     ParameterInteger,
     ParameterString,
@@ -97,8 +98,32 @@ def pipeline_name():
     return f"my-pipeline-{int(time.time() * 10**7)}"


+@pytest.fixture
+def athena_dataset_definition(sagemaker_session):
+    return DatasetDefinition(
+        local_path="/opt/ml/processing/input/add",
+        data_distribution_type="FullyReplicated",
+        input_mode="File",
+        athena_dataset_definition=AthenaDatasetDefinition(
+            catalog="AwsDataCatalog",
+            database="default",
+            work_group="workgroup",
+            query_string='SELECT * FROM "default"."s3_test_table_$STAGE_$REGIONUNDERSCORED";',
+            output_s3_uri=f"s3://{sagemaker_session.default_bucket()}/add",
+            output_format="JSON",
+            output_compression="GZIP",
+        ),
+    )
+
+
 def test_three_step_definition(
-    sagemaker_session, workflow_session, region_name, role, script_dir, pipeline_name
+    sagemaker_session,
+    workflow_session,
+    region_name,
+    role,
+    script_dir,
+    pipeline_name,
+    athena_dataset_definition,
 ):
     framework_version = "0.20.0"
     instance_type = ParameterString(name="InstanceType", default_value="ml.m5.xlarge")
@@ -117,7 +142,10 @@ def test_three_step_definition(
     step_process = ProcessingStep(
         name="my-process",
         processor=sklearn_processor,
-        inputs=[ProcessingInput(source=input_data, destination="/opt/ml/processing/input")],
+        inputs=[
+            ProcessingInput(source=input_data, destination="/opt/ml/processing/input"),
+            ProcessingInput(dataset_definition=athena_dataset_definition),
+        ],
         outputs=[
             ProcessingOutput(output_name="train_data", source="/opt/ml/processing/train"),
             ProcessingOutput(output_name="test_data", source="/opt/ml/processing/test"),
@@ -228,11 +256,15 @@ def test_one_step_sklearn_processing_pipeline(
     cpu_instance_type,
     pipeline_name,
     region,
+    athena_dataset_definition,
 ):
     instance_count = ParameterInteger(name="InstanceCount", default_value=2)
     script_path = os.path.join(DATA_DIR, "dummy_script.py")
     input_file_path = os.path.join(DATA_DIR, "dummy_input.txt")
-    inputs = [ProcessingInput(source=input_file_path, destination="/opt/ml/processing/inputs/")]
+    inputs = [
+        ProcessingInput(source=input_file_path, destination="/opt/ml/processing/inputs/"),
+        ProcessingInput(dataset_definition=athena_dataset_definition),
+    ]

     sklearn_processor = SKLearnProcessor(
         framework_version=sklearn_latest_version,
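
For context, the new fixture wires an Athena query into a processing job via ProcessingInput(dataset_definition=...): SageMaker runs the query, writes the results to output_s3_uri, and stages them at local_path inside the container. Below is a minimal sketch of the same pattern outside the pipeline tests; the role ARN, workgroup, table name, and script path are placeholder assumptions, not values from this commit.

# Minimal sketch: run an SKLearnProcessor with an Athena-backed input,
# mirroring the fixture added in this commit. Role ARN, workgroup,
# table name, and script path below are placeholder assumptions.
import sagemaker
from sagemaker.processing import ProcessingInput
from sagemaker.sklearn.processing import SKLearnProcessor
from sagemaker.dataset_definition.inputs import (
    DatasetDefinition,
    AthenaDatasetDefinition,
)

session = sagemaker.Session()
role = "arn:aws:iam::123456789012:role/SageMakerExecutionRole"  # placeholder

dataset = DatasetDefinition(
    local_path="/opt/ml/processing/input/add",  # where query results are staged
    data_distribution_type="FullyReplicated",
    input_mode="File",
    athena_dataset_definition=AthenaDatasetDefinition(
        catalog="AwsDataCatalog",
        database="default",
        work_group="primary",                    # placeholder workgroup
        query_string='SELECT * FROM "default"."my_table";',  # placeholder table
        output_s3_uri=f"s3://{session.default_bucket()}/add",
        output_format="JSON",
        output_compression="GZIP",
    ),
)

processor = SKLearnProcessor(
    framework_version="0.23-1",
    role=role,
    instance_type="ml.m5.xlarge",
    instance_count=1,
    sagemaker_session=session,
)

# SageMaker executes the Athena query, writes the results to
# output_s3_uri, and mounts them at local_path in the container.
processor.run(
    code="dummy_script.py",  # placeholder script
    inputs=[ProcessingInput(dataset_definition=dataset)],
)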
