
Commit 868db55

add helper function to generate no-op (data ingestion only) recipe
1 parent 1a5d9c9 commit 868db55

File tree

2 files changed, +420 -1 lines


src/sagemaker/workflow/utilities.py

Lines changed: 108 additions & 1 deletion
@@ -13,13 +13,18 @@
 """Utilities to support workflow."""
 from __future__ import absolute_import

-from typing import List, Sequence, Union
+from typing import List, Sequence, Union, Dict

 from sagemaker.workflow.entities import (
     Entity,
     RequestType,
 )
 from sagemaker.workflow.step_collections import StepCollection
+from sagemaker.dataset_definition.inputs import (
+    RedshiftDatasetDefinition,
+    AthenaDatasetDefinition,
+)
+from uuid import uuid4


 def list_to_request(entities: Sequence[Union[Entity, StepCollection]]) -> List[RequestType]:
@@ -37,3 +42,105 @@ def list_to_request(entities: Sequence[Union[Entity, StepCollection]]) -> List[RequestType]:
     elif isinstance(entity, StepCollection):
         request_dicts.extend(entity.request_dicts())
     return request_dicts
+
+
+def generate_data_ingestion_flow_recipe(
+    input_name: str,
+    s3_uri: str = None,
+    s3_content_type: str = "csv",
+    s3_has_header: bool = False,
+    athena_dataset_definition: AthenaDatasetDefinition = None,
+    redshift_dataset_definition: RedshiftDatasetDefinition = None,
+) -> Dict:
+    """Generate the data ingestion only flow recipe.
+
+    Args:
+        input_name (str): name of the input to the recipe source node
+        s3_uri (str): s3 input uri
+        s3_content_type (str): s3 input content type
+        s3_has_header (bool): flag indicating whether the s3 input has a header
+        athena_dataset_definition (AthenaDatasetDefinition): athena input to recipe source node
+        redshift_dataset_definition (RedshiftDatasetDefinition): redshift input to recipe source node
+    Returns:
+        dict: A flow recipe that only performs data ingestion with a 1-1 mapping
+    """
+    if s3_uri is None and athena_dataset_definition is None and redshift_dataset_definition is None:
+        raise ValueError(
+            "One of s3 input, athena dataset definition, or redshift dataset "
+            "definition needs to be given."
+        )
+
+    recipe = {"metadata": {"version": 1, "disable_limits": False}, "nodes": []}
+
+    # SOURCE node: reads the raw input; the dataset definition is attached below.
+    source_node = {
+        "node_id": str(uuid4()),
+        "type": "SOURCE",
+        "inputs": [],
+        "outputs": [
+            {
+                "name": "default",
+                "sampling": {"sampling_method": "sample_by_limit", "limit_rows": 50000},
+            }
+        ],
+    }
+
+    input_definition = None
+    operator = None
+
+    # Exactly one input source is used: s3 takes precedence, then athena, then redshift.
+    if s3_uri is not None:
+        operator = "sagemaker.s3_source_0.1"
+        input_definition = {
+            "__typename": "S3CreateDatasetDefinitionOutput",
+            "datasetSourceType": "S3",
+            "name": input_name,
+            "description": None,
+            "s3ExecutionContext": {
+                "__typename": "S3ExecutionContext",
+                "s3Uri": s3_uri,
+                "s3ContentType": s3_content_type,
+                "s3HasHeader": s3_has_header,
+            },
+        }
+
+    if input_definition is None and athena_dataset_definition is not None:
+        operator = "sagemaker.athena_source_0.1"
+        input_definition = {
+            "datasetSourceType": "Athena",
+            "name": input_name,
+            "catalogName": athena_dataset_definition.catalog,
+            "databaseName": athena_dataset_definition.database,
+            "queryString": athena_dataset_definition.query_string,
+            "s3OutputLocation": athena_dataset_definition.output_s3_uri,
+            "outputFormat": athena_dataset_definition.output_format,
+        }
+
+    if input_definition is None and redshift_dataset_definition is not None:
+        operator = "sagemaker.redshift_source_0.1"
+        input_definition = {
+            "datasetSourceType": "Redshift",
+            "name": input_name,
+            "clusterIdentifier": redshift_dataset_definition.cluster_id,
+            "database": redshift_dataset_definition.database,
+            "dbUser": redshift_dataset_definition.db_user,
+            "queryString": redshift_dataset_definition.query_string,
+            "unloadIamRole": redshift_dataset_definition.cluster_role_arn,
+            "s3OutputLocation": redshift_dataset_definition.output_s3_uri,
+            "outputFormat": redshift_dataset_definition.output_format,
+        }
+
+    source_node["operator"] = operator
+    source_node["parameters"] = {"dataset_definition": input_definition}
+
+    recipe["nodes"].append(source_node)
+
+    # TRANSFORM node: only infers and casts column types, so the recipe is a no-op ingestion.
+    type_infer_and_cast_node = {
+        "node_id": str(uuid4()),
+        "type": "TRANSFORM",
+        "operator": "sagemaker.spark.infer_and_cast_type_0.1",
+        "parameters": {},
+        "inputs": [
+            {"name": "default", "node_id": source_node["node_id"], "output_name": "default"}
+        ],
+        "outputs": [{"name": "default"}],
+    }
+
+    recipe["nodes"].append(type_infer_and_cast_node)
+
+    return recipe
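
A minimal usage sketch of the new helper (not part of the commit): the bucket, database, table, and query values below are placeholders, and it assumes AthenaDatasetDefinition accepts these fields as keyword arguments.

    from sagemaker.dataset_definition.inputs import AthenaDatasetDefinition
    from sagemaker.workflow.utilities import generate_data_ingestion_flow_recipe

    # S3 input: only s3_uri is required; content type defaults to "csv", header flag to False.
    s3_recipe = generate_data_ingestion_flow_recipe(
        input_name="raw-data.csv",                      # placeholder input name
        s3_uri="s3://example-bucket/raw/raw-data.csv",  # placeholder uri
        s3_has_header=True,
    )

    # Athena input: the dataset definition's fields are copied into the recipe's SOURCE node.
    athena_recipe = generate_data_ingestion_flow_recipe(
        input_name="athena-input",
        athena_dataset_definition=AthenaDatasetDefinition(
            catalog="AwsDataCatalog",
            database="example_db",
            query_string="SELECT * FROM example_table",
            output_s3_uri="s3://example-bucket/athena-output/",
            output_format="PARQUET",
        ),
    )

    # Each recipe is a plain dict: a SOURCE node followed by an infer-and-cast TRANSFORM node.
    print([node["operator"] for node in s3_recipe["nodes"]])
    # ['sagemaker.s3_source_0.1', 'sagemaker.spark.infer_and_cast_type_0.1']

Because the return value is an ordinary dict, it can be serialized with json.dump when a flow file is needed downstream.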
