Skip to content

Commit 69ae4bd

Browse files
committed
create an internal helper function to generate output node
1 parent 854dd10 commit 69ae4bd

File tree

1 file changed

+25
-38
lines changed

1 file changed

+25
-38
lines changed

src/sagemaker/wrangler/ingestion.py

Lines changed: 25 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -63,19 +63,7 @@ def generate_data_ingestion_flow_from_s3_input(
6363
},
6464
}
6565

66-
output_node = {
67-
"node_id": str(uuid4()),
68-
"type": "TRANSFORM",
69-
"operator": f"sagemaker.spark.infer_and_cast_type_{operator_version}",
70-
"parameters": {},
71-
"inputs": [
72-
{"name": "default", "node_id": source_node["node_id"], "output_name": "default"}
73-
],
74-
"outputs": [{"name": "default"}],
75-
}
76-
77-
if schema:
78-
output_node["trained_parameters"] = schema
66+
output_node = _get_output_node(source_node['node_id'], operator_version, schema)
7967

8068
flow = {
8169
"metadata": {"version": 1, "disable_limits": False},
@@ -122,19 +110,7 @@ def generate_data_ingestion_flow_from_athena_dataset_definition(
122110
},
123111
}
124112

125-
output_node = {
126-
"node_id": str(uuid4()),
127-
"type": "TRANSFORM",
128-
"operator": f"sagemaker.spark.infer_and_cast_type_{operator_version}",
129-
"parameters": {},
130-
"inputs": [
131-
{"name": "default", "node_id": source_node["node_id"], "output_name": "default"}
132-
],
133-
"outputs": [{"name": "default"}],
134-
}
135-
136-
if schema:
137-
output_node["trained_parameters"] = schema
113+
output_node = _get_output_node(source_node['node_id'], operator_version, schema)
138114

139115
flow = {
140116
"metadata": {"version": 1, "disable_limits": False},
@@ -183,23 +159,34 @@ def generate_data_ingestion_flow_from_redshift_dataset_definition(
183159
},
184160
}
185161

186-
output_node = {
162+
output_node = _get_output_node(source_node['node_id'], operator_version, schema)
163+
164+
flow = {
165+
"metadata": {"version": 1, "disable_limits": False},
166+
"nodes": [source_node, output_node]
167+
}
168+
169+
return flow, f'{output_node["node_id"]}.default'
170+
171+
172+
def _get_output_node(source_node_id: str, operator_version: str, schema: Dict):
173+
"""A helper function to generate output node, for internal use only
174+
175+
Args:
176+
source_node_id (str): source node id
177+
operator_version: (str): the version of the operator
178+
schema: (typing.Dict): the schema for the data to be ingested
179+
Returns:
180+
dict (typing.Dict): output node
181+
"""
182+
return {
187183
"node_id": str(uuid4()),
188184
"type": "TRANSFORM",
189185
"operator": f"sagemaker.spark.infer_and_cast_type_{operator_version}",
186+
"trained_parameters": {} if schema is None else schema,
190187
"parameters": {},
191188
"inputs": [
192-
{"name": "default", "node_id": source_node["node_id"], "output_name": "default"}
189+
{"name": "default", "node_id": source_node_id, "output_name": "default"}
193190
],
194191
"outputs": [{"name": "default"}],
195192
}
196-
197-
if schema:
198-
output_node["trained_parameters"] = schema
199-
200-
flow = {
201-
"metadata": {"version": 1, "disable_limits": False},
202-
"nodes": [source_node, output_node],
203-
}
204-
205-
return flow, f'{output_node["node_id"]}.default'

0 commit comments

Comments
 (0)