@@ -63,19 +63,7 @@ def generate_data_ingestion_flow_from_s3_input(
         },
     }
 
-    output_node = {
-        "node_id": str(uuid4()),
-        "type": "TRANSFORM",
-        "operator": f"sagemaker.spark.infer_and_cast_type_{operator_version}",
-        "parameters": {},
-        "inputs": [
-            {"name": "default", "node_id": source_node["node_id"], "output_name": "default"}
-        ],
-        "outputs": [{"name": "default"}],
-    }
-
-    if schema:
-        output_node["trained_parameters"] = schema
+    output_node = _get_output_node(source_node['node_id'], operator_version, schema)
 
     flow = {
         "metadata": {"version": 1, "disable_limits": False},
@@ -122,19 +110,7 @@ def generate_data_ingestion_flow_from_athena_dataset_definition(
         },
     }
 
-    output_node = {
-        "node_id": str(uuid4()),
-        "type": "TRANSFORM",
-        "operator": f"sagemaker.spark.infer_and_cast_type_{operator_version}",
-        "parameters": {},
-        "inputs": [
-            {"name": "default", "node_id": source_node["node_id"], "output_name": "default"}
-        ],
-        "outputs": [{"name": "default"}],
-    }
-
-    if schema:
-        output_node["trained_parameters"] = schema
+    output_node = _get_output_node(source_node['node_id'], operator_version, schema)
 
     flow = {
         "metadata": {"version": 1, "disable_limits": False},
@@ -183,23 +159,34 @@ def generate_data_ingestion_flow_from_redshift_dataset_definition(
         },
     }
 
-    output_node = {
+    output_node = _get_output_node(source_node['node_id'], operator_version, schema)
+
+    flow = {
+        "metadata": {"version": 1, "disable_limits": False},
+        "nodes": [source_node, output_node]
+    }
+
+    return flow, f'{output_node["node_id"]}.default'
+
+
+def _get_output_node(source_node_id: str, operator_version: str, schema: Dict):
+    """A helper function to generate the output node, for internal use only.
+
+    Args:
+        source_node_id (str): source node id
+        operator_version (str): the version of the operator
+        schema (typing.Dict): the schema for the data to be ingested
+    Returns:
+        dict (typing.Dict): output node
+    """
+    return {
         "node_id": str(uuid4()),
         "type": "TRANSFORM",
         "operator": f"sagemaker.spark.infer_and_cast_type_{operator_version}",
+        "trained_parameters": {} if schema is None else schema,
         "parameters": {},
         "inputs": [
-            {"name": "default", "node_id": source_node["node_id"], "output_name": "default"}
+            {"name": "default", "node_id": source_node_id, "output_name": "default"}
        ],
         "outputs": [{"name": "default"}],
     }
-
-    if schema:
-        output_node["trained_parameters"] = schema
-
-    flow = {
-        "metadata": {"version": 1, "disable_limits": False},
-        "nodes": [source_node, output_node],
-    }
-
-    return flow, f'{output_node["node_id"]}.default'
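For a quick sanity check of the refactor, here is a minimal standalone sketch. It reproduces the new `_get_output_node` helper as introduced in this diff (docstring shortened) and exercises it with a hypothetical source node id and operator version, both invented for illustration. It also shows the one behavioral nuance of the change: the old inline code only set `trained_parameters` when `schema` was truthy, whereas the helper always includes the key, normalizing a `None` schema to `{}`.

```python
from typing import Dict
from uuid import uuid4


def _get_output_node(source_node_id: str, operator_version: str, schema: Dict):
    """The helper as added in this diff (docstring shortened for the demo)."""
    return {
        "node_id": str(uuid4()),
        "type": "TRANSFORM",
        "operator": f"sagemaker.spark.infer_and_cast_type_{operator_version}",
        # A None schema is normalized to {}; the old inline code omitted the
        # key entirely when schema was falsy.
        "trained_parameters": {} if schema is None else schema,
        "parameters": {},
        "inputs": [
            {"name": "default", "node_id": source_node_id, "output_name": "default"}
        ],
        "outputs": [{"name": "default"}],
    }


# Hypothetical values, for illustration only.
source_id = str(uuid4())
node = _get_output_node(source_id, "0.1", None)

assert node["inputs"][0]["node_id"] == source_id
assert node["trained_parameters"] == {}
# The public functions return this string alongside the flow dict; the
# ".default" suffix matches the single output name declared on the node.
print(f'{node["node_id"]}.default')
```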