24
24
25
25
from sagemaker ._studio import _append_project_tags
26
26
from sagemaker .session import Session
27
- from sagemaker .workflow .callback_step import CallbackOutput
27
+ from sagemaker .workflow .callback_step import CallbackOutput , CallbackStep
28
28
from sagemaker .workflow .entities import (
29
29
Entity ,
30
30
Expression ,
@@ -242,7 +242,10 @@ def definition(self) -> str:
242
242
request_dict ["PipelineExperimentConfig" ] = interpolate (
243
243
request_dict ["PipelineExperimentConfig" ]
244
244
)
245
- request_dict ["Steps" ] = interpolate (request_dict ["Steps" ])
245
+ callback_output_to_step_map = _map_callback_outputs (self .steps )
246
+ request_dict ["Steps" ] = interpolate (
247
+ request_dict ["Steps" ], callback_output_to_step_map = callback_output_to_step_map
248
+ )
246
249
247
250
return json .dumps (request_dict )
248
251
@@ -263,7 +266,7 @@ def format_start_parameters(parameters: Dict[str, Any]) -> List[Dict[str, Any]]:
263
266
return [{"Name" : name , "Value" : str (value )} for name , value in parameters .items ()]
264
267
265
268
266
- def interpolate (request_obj : RequestType ) -> RequestType :
269
+ def interpolate (request_obj : RequestType , ** kwargs ) -> RequestType :
267
270
"""Replaces Parameter values in a list of nested Dict[str, Any] with their workflow expression.
268
271
269
272
Args:
@@ -273,28 +276,59 @@ def interpolate(request_obj: RequestType) -> RequestType:
273
276
RequestType: The request dict with Parameter values replaced by their expression.
274
277
"""
275
278
request_obj_copy = deepcopy (request_obj )
276
- return _interpolate (request_obj_copy )
279
+ return _interpolate (
280
+ request_obj_copy ,
281
+ callback_output_to_step_map = kwargs .get ("callback_output_to_step_map" , None ),
282
+ )
277
283
278
284
279
- def _interpolate (obj : Union [RequestType , Any ]):
285
+ def _interpolate (obj : Union [RequestType , Any ], ** kwargs ):
280
286
"""Walks the nested request dict, replacing Parameter type values with workflow expressions.
281
287
282
288
Args:
283
289
obj (Union[RequestType, Any]): The request dict.
284
290
"""
285
- if isinstance (obj , (Expression , Parameter , Properties , CallbackOutput )):
291
+ if isinstance (obj , (Expression , Parameter , Properties )):
286
292
return obj .expr
293
+ if isinstance (obj , CallbackOutput ):
294
+ callback_output_to_step_map = kwargs .get ("callback_output_to_step_map" , {})
295
+ step_name = callback_output_to_step_map [obj .output_name ]
296
+ return obj .expr (step_name )
287
297
if isinstance (obj , dict ):
288
298
new = obj .__class__ ()
289
299
for key , value in obj .items ():
290
- new [key ] = interpolate (value )
300
+ new [key ] = interpolate (
301
+ value , callback_output_to_step_map = kwargs .get ("callback_output_to_step_map" , None )
302
+ )
291
303
elif isinstance (obj , (list , set , tuple )):
292
- new = obj .__class__ (interpolate (value ) for value in obj )
304
+ new = obj .__class__ (
305
+ interpolate (
306
+ value , callback_output_to_step_map = kwargs .get ("callback_output_to_step_map" , None )
307
+ )
308
+ for value in obj
309
+ )
293
310
else :
294
311
return obj
295
312
return new
296
313
297
314
315
+ def _map_callback_outputs (steps : List [Step ]):
316
+ """Iterate over the provided steps, building a map of callback output parameters to step names.
317
+
318
+ Args:
319
+ step (List[Step]): The steps list.
320
+ """
321
+
322
+ callback_output_map = {}
323
+ for step in steps :
324
+ if isinstance (step , CallbackStep ):
325
+ if step .outputs :
326
+ for output in step .outputs :
327
+ callback_output_map [output .output_name ] = step .name
328
+
329
+ return callback_output_map
330
+
331
+
298
332
def update_args (args : Dict [str , Any ], ** kwargs ):
299
333
"""Updates the request arguments dict with a value, if populated.
300
334
0 commit comments