
Commit 854dd10

separate flow generation by source input type + move generation helpers to sagemaker.wrangler.ingestion
1 parent 21bedbb commit 854dd10

File tree

6 files changed: +391 −419 lines


src/sagemaker/workflow/utilities.py

Lines changed: 1 addition & 108 deletions
```diff
@@ -13,18 +13,13 @@
 """Utilities to support workflow."""
 from __future__ import absolute_import

-from typing import List, Sequence, Union, Dict
+from typing import List, Sequence, Union

 from sagemaker.workflow.entities import (
     Entity,
     RequestType,
 )
 from sagemaker.workflow.step_collections import StepCollection
-from sagemaker.dataset_definition.inputs import (
-    RedshiftDatasetDefinition,
-    AthenaDatasetDefinition,
-)
-from uuid import uuid4


 def list_to_request(entities: Sequence[Union[Entity, StepCollection]]) -> List[RequestType]:
@@ -42,105 +37,3 @@ def list_to_request(entities: Sequence[Union[Entity, StepCollection]]) -> List[R
         elif isinstance(entity, StepCollection):
             request_dicts.extend(entity.request_dicts())
     return request_dicts
-
-
-def generate_data_ingestion_flow_recipe(
-    input_name: str,
-    s3_uri: str = None,
-    s3_content_type: str = "csv",
-    s3_has_header: bool = False,
-    athena_dataset_definition: AthenaDatasetDefinition = None,
-    redshift_dataset_definition: RedshiftDatasetDefinition = None,
-) -> Dict:
-    """Generate the data ingestion only flow recipe
-
-    Args:
-        input_name (str): s3 input to recipe source node
-        s3_uri (str): s3 input uri
-        s3_content_type (str): s3 input content type
-        s3_has_header (bool): flag indicating the input has header or not
-        athena_dataset_definition (AthenaDatasetDefinition): athena input to recipe source node
-        redshift_dataset_definition (RedshiftDatasetDefinition): redshift input to recipe source node
-    Returns:
-        dict: A flow recipe only conduct data ingestion with 1-1 mapping
-    """
-    if s3_uri is None and athena_dataset_definition is None and redshift_dataset_definition is None:
-        raise ValueError("One of s3 input, athena dataset definition, or redshift dataset definition need to be given.")
-
-    recipe = {"metadata": {"version": 1, "disable_limits": False}, "nodes": []}
-
-    source_node = {
-        "node_id": str(uuid4()),
-        "type": "SOURCE",
-        "inputs": [],
-        "outputs": [
-            {
-                "name": "default",
-                "sampling": {"sampling_method": "sample_by_limit", "limit_rows": 50000},
-            }
-        ],
-    }
-
-    input_definition = None
-    operator = None
-
-    if s3_uri is not None:
-        operator = "sagemaker.s3_source_0.1"
-        input_definition = {
-            "__typename": "S3CreateDatasetDefinitionOutput",
-            "datasetSourceType": "S3",
-            "name": input_name,
-            "description": None,
-            "s3ExecutionContext": {
-                "__typename": "S3ExecutionContext",
-                "s3Uri": s3_uri,
-                "s3ContentType": s3_content_type,
-                "s3HasHeader": s3_has_header,
-            },
-        }
-
-    if input_definition is None and athena_dataset_definition is not None:
-        operator = "sagemaker.athena_source_0.1"
-        input_definition = {
-            "datasetSourceType": "Athena",
-            "name": input_name,
-            "catalogName": athena_dataset_definition.catalog,
-            "databaseName": athena_dataset_definition.database,
-            "queryString": athena_dataset_definition.query_string,
-            "s3OutputLocation": athena_dataset_definition.output_s3_uri,
-            "outputFormat": athena_dataset_definition.output_format,
-        }
-
-    if input_definition is None and redshift_dataset_definition is not None:
-        operator = "sagemaker.redshift_source_0.1"
-        input_definition = {
-            "datasetSourceType": "Redshift",
-            "name": input_name,
-            "clusterIdentifier": redshift_dataset_definition.cluster_id,
-            "database": redshift_dataset_definition.database,
-            "dbUser": redshift_dataset_definition.db_user,
-            "queryString": redshift_dataset_definition.query_string,
-            "unloadIamRole": redshift_dataset_definition.cluster_role_arn,
-            "s3OutputLocation": redshift_dataset_definition.output_s3_uri,
-            "outputFormat": redshift_dataset_definition.output_format,
-        }
-
-    source_node["operator"] = operator
-    source_node["parameters"] = {"dataset_definition": input_definition}
-
-    recipe["nodes"].append(source_node)
-
-    type_infer_and_cast_node = {
-        "node_id": str(uuid4()),
-        "type": "TRANSFORM",
-        "operator": "sagemaker.spark.infer_and_cast_type_0.1",
-        "parameters": {},
-        "inputs": [
-            {"name": "default", "node_id": source_node["node_id"], "output_name": "default"}
-        ],
-        "outputs": [{"name": "default"}],
-    }
-
-    recipe["nodes"].append(type_infer_and_cast_node)
-
-    return recipe
```
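With this change, callers that previously imported the combined helper from `sagemaker.workflow.utilities` switch to the per-source helpers in `sagemaker.wrangler.ingestion`. A minimal before/after sketch based on the signatures in this diff (the input name and S3 URI are hypothetical placeholders, not values from the commit):

```python
# Before (removed in this commit): one helper covered S3, Athena, and Redshift.
# from sagemaker.workflow.utilities import generate_data_ingestion_flow_recipe
# recipe = generate_data_ingestion_flow_recipe(
#     input_name="my-input",                   # hypothetical name
#     s3_uri="s3://my-bucket/raw/data.csv",    # hypothetical URI
#     s3_content_type="csv",
#     s3_has_header=True,
# )

# After: pick the helper for the source type; it also returns the output name.
from sagemaker.wrangler.ingestion import generate_data_ingestion_flow_from_s3_input

flow, output_name = generate_data_ingestion_flow_from_s3_input(
    input_name="my-input",                     # hypothetical name
    s3_uri="s3://my-bucket/raw/data.csv",      # hypothetical URI
    s3_content_type="csv",
    s3_has_header=True,
)
```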

src/sagemaker/wrangler/__init__.py

Whitespace-only changes.

src/sagemaker/wrangler/ingestion.py

Lines changed: 205 additions & 0 deletions
```diff
@@ -0,0 +1,205 @@
+# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+#     http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+# language governing permissions and limitations under the License.
+"""Data wrangler helpers for data ingestion."""
+from __future__ import absolute_import
+
+from typing import Dict
+from uuid import uuid4
+from sagemaker.dataset_definition.inputs import (
+    RedshiftDatasetDefinition,
+    AthenaDatasetDefinition,
+)
+
+
+def generate_data_ingestion_flow_from_s3_input(
+    input_name: str,
+    s3_uri: str,
+    s3_content_type: str = "csv",
+    s3_has_header: bool = False,
+    operator_version: str = "0.1",
+    schema: Dict = None,
+):
+    """Generate the data ingestion only flow from s3 input
+
+    Args:
+        input_name (str): the name of the input to flow source node
+        s3_uri (str): uri for the s3 input to flow source node
+        s3_content_type (str): s3 input content type
+        s3_has_header (bool): flag indicating the input has header or not
+        operator_version: (str): the version of the operator
+        schema: (typing.Dict): the schema for the data to be ingested
+    Returns:
+        dict (typing.Dict): A flow only conduct data ingestion with 1-1 mapping
+        output_name (str): The output name used to configure
+            `sagemaker.processing.FeatureStoreOutput`
+    """
+    source_node = {
+        "node_id": str(uuid4()),
+        "type": "SOURCE",
+        "inputs": [],
+        "outputs": [{"name": "default"}],
+        "operator": f"sagemaker.s3_source_{operator_version}",
+        "parameters": {
+            "dataset_definition": {
+                "datasetSourceType": "S3",
+                "name": input_name,
+                "s3ExecutionContext": {
+                    "s3Uri": s3_uri,
+                    "s3ContentType": s3_content_type,
+                    "s3HasHeader": s3_has_header,
+                },
+            }
+        },
+    }
+
+    output_node = {
+        "node_id": str(uuid4()),
+        "type": "TRANSFORM",
+        "operator": f"sagemaker.spark.infer_and_cast_type_{operator_version}",
+        "parameters": {},
+        "inputs": [
+            {"name": "default", "node_id": source_node["node_id"], "output_name": "default"}
+        ],
+        "outputs": [{"name": "default"}],
+    }
+
+    if schema:
+        output_node["trained_parameters"] = schema
+
+    flow = {
+        "metadata": {"version": 1, "disable_limits": False},
+        "nodes": [source_node, output_node],
+    }
+
+    return flow, f'{output_node["node_id"]}.default'
+
+
+def generate_data_ingestion_flow_from_athena_dataset_definition(
+    input_name: str,
+    athena_dataset_definition: AthenaDatasetDefinition,
+    operator_version: str = "0.1",
+    schema: Dict = None,
+):
+    """Generate the data ingestion only flow from athena input
+
+    Args:
+        input_name (str): the name of the input to flow source node
+        athena_dataset_definition (AthenaDatasetDefinition): athena input to flow source node
+        operator_version: (str): the version of the operator
+        schema: (typing.Dict): the schema for the data to be ingested
+    Returns:
+        dict (typing.Dict): A flow only conduct data ingestion with 1-1 mapping
+        output_name (str): The output name used to configure
+            `sagemaker.processing.FeatureStoreOutput`
+    """
+    source_node = {
+        "node_id": str(uuid4()),
+        "type": "SOURCE",
+        "inputs": [],
+        "outputs": [{"name": "default"}],
+        "operator": f"sagemaker.athena_source_{operator_version}",
+        "parameters": {
+            "dataset_definition": {
+                "datasetSourceType": "Athena",
+                "name": input_name,
+                "catalogName": athena_dataset_definition.catalog,
+                "databaseName": athena_dataset_definition.database,
+                "queryString": athena_dataset_definition.query_string,
+                "s3OutputLocation": athena_dataset_definition.output_s3_uri,
+                "outputFormat": athena_dataset_definition.output_format,
+            }
+        },
+    }
+
+    output_node = {
+        "node_id": str(uuid4()),
+        "type": "TRANSFORM",
+        "operator": f"sagemaker.spark.infer_and_cast_type_{operator_version}",
+        "parameters": {},
+        "inputs": [
+            {"name": "default", "node_id": source_node["node_id"], "output_name": "default"}
+        ],
+        "outputs": [{"name": "default"}],
+    }
+
+    if schema:
+        output_node["trained_parameters"] = schema
+
+    flow = {
+        "metadata": {"version": 1, "disable_limits": False},
+        "nodes": [source_node, output_node],
+    }
+
+    return flow, f'{output_node["node_id"]}.default'
+
+
+def generate_data_ingestion_flow_from_redshift_dataset_definition(
+    input_name: str,
+    redshift_dataset_definition: RedshiftDatasetDefinition,
+    operator_version: str = "0.1",
+    schema: Dict = None,
+):
+    """Generate the data ingestion only flow from redshift input
+
+    Args:
+        input_name (str): the name of the input to flow source node
+        redshift_dataset_definition (RedshiftDatasetDefinition): redshift input to flow source node
+        operator_version: (str): the version of the operator
+        schema: (typing.Dict): the schema for the data to be ingested
+    Returns:
+        dict (typing.Dict): A flow only conduct data ingestion with 1-1 mapping
+        output_name (str): The output name used to configure
+            `sagemaker.processing.FeatureStoreOutput`
+    """
+    source_node = {
+        "node_id": str(uuid4()),
+        "type": "SOURCE",
+        "inputs": [],
+        "outputs": [{"name": "default"}],
+        "operator": f"sagemaker.redshift_source_{operator_version}",
+        "parameters": {
+            "dataset_definition": {
+                "datasetSourceType": "Redshift",
+                "name": input_name,
+                "clusterIdentifier": redshift_dataset_definition.cluster_id,
+                "database": redshift_dataset_definition.database,
+                "dbUser": redshift_dataset_definition.db_user,
+                "queryString": redshift_dataset_definition.query_string,
+                "unloadIamRole": redshift_dataset_definition.cluster_role_arn,
+                "s3OutputLocation": redshift_dataset_definition.output_s3_uri,
+                "outputFormat": redshift_dataset_definition.output_format,
+            }
+        },
+    }
+
+    output_node = {
+        "node_id": str(uuid4()),
+        "type": "TRANSFORM",
+        "operator": f"sagemaker.spark.infer_and_cast_type_{operator_version}",
+        "parameters": {},
+        "inputs": [
+            {"name": "default", "node_id": source_node["node_id"], "output_name": "default"}
+        ],
+        "outputs": [{"name": "default"}],
+    }
+
+    if schema:
+        output_node["trained_parameters"] = schema
+
+    flow = {
+        "metadata": {"version": 1, "disable_limits": False},
+        "nodes": [source_node, output_node],
+    }
+
+    return flow, f'{output_node["node_id"]}.default'
```
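As a quick usage sketch for the new module (not part of the commit; the Athena catalog, query, S3 locations, and file name below are hypothetical, and `AthenaDatasetDefinition` is assumed to accept its attributes as keyword arguments, as their use above suggests): generate a flow from an Athena dataset definition, serialize it, and keep the returned output name for wiring up `sagemaker.processing.FeatureStoreOutput` as the docstrings describe.

```python
import json

from sagemaker.dataset_definition.inputs import AthenaDatasetDefinition
from sagemaker.wrangler.ingestion import (
    generate_data_ingestion_flow_from_athena_dataset_definition,
)

# Hypothetical Athena source; every literal below is a placeholder.
athena_input = AthenaDatasetDefinition(
    catalog="AwsDataCatalog",
    database="my_database",
    query_string="SELECT * FROM my_table",
    output_s3_uri="s3://my-bucket/athena-output/",
    output_format="PARQUET",
)

flow, output_name = generate_data_ingestion_flow_from_athena_dataset_definition(
    input_name="my-athena-input",
    athena_dataset_definition=athena_input,
)

# The flow is a plain dict, so it can be written out (e.g. as a .flow document)
# for a Data Wrangler processing job; output_name ("<transform node_id>.default")
# is what a FeatureStoreOutput-based processing output would reference.
with open("ingestion.flow", "w") as f:
    json.dump(flow, f, indent=2)
print(output_name)
```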
