Skip to content

Commit 7ff222f

Browse files
author
Payton Staub
committed
Correctly interpolate Callback output parameters that are passed by reference
1 parent c2fbe75 commit 7ff222f

File tree

3 files changed, with 51 additions and 15 deletions.

src/sagemaker/workflow/callback_step.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -58,19 +58,18 @@ def to_request(self) -> RequestType:
5858
"OutputType": self.output_type.value,
5959
}
6060

61-
@property
62-
def expr(self) -> Dict[str, str]:
63-
"""The 'Get' expression dict for a `Parameter`."""
64-
return CallbackOutput._expr(self.output_name)
61+
def expr(self, step_name: str) -> Dict[str, str]:
    """The 'Get' expression dict for a `CallbackOutput`.

    Args:
        step_name (str): The step name to reference in the expression; it is
            used as the `Steps.<step_name>` segment of the 'Get' string.

    Returns:
        Dict[str, str]: The dict built by `CallbackOutput._expr` for this
            output's `output_name` and the given `step_name`.
    """
    return CallbackOutput._expr(self.output_name, step_name)
6564

6665
@classmethod
def _expr(cls, name, step_name):
    """An internal classmethod for the 'Get' expression dict for a `CallbackOutput`.

    Args:
        name (str): The name of the callback output.
        step_name (str): The step name used as the `Steps.<step_name>` segment
            of the expression.

    Returns:
        Dict[str, str]: A single-entry dict whose "Get" value is
            `Steps.<step_name>.OutputParameters['<name>']`.
    """
    return {"Get": f"Steps.{step_name}.OutputParameters['{name}']"}
7473

7574

7675
class CallbackStep(Step):

src/sagemaker/workflow/pipeline.py

Lines changed: 45 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424

2525
from sagemaker._studio import _append_project_tags
2626
from sagemaker.session import Session
27-
from sagemaker.workflow.callback_step import CallbackOutput
27+
from sagemaker.workflow.callback_step import CallbackOutput, CallbackStep
2828
from sagemaker.workflow.entities import (
2929
Entity,
3030
Expression,
@@ -242,7 +242,10 @@ def definition(self) -> str:
242242
request_dict["PipelineExperimentConfig"] = interpolate(
243243
request_dict["PipelineExperimentConfig"]
244244
)
245-
request_dict["Steps"] = interpolate(request_dict["Steps"])
245+
callback_output_to_step_map = _map_callback_outputs(self.steps)
246+
request_dict["Steps"] = interpolate(
247+
request_dict["Steps"], callback_output_to_step_map=callback_output_to_step_map
248+
)
246249

247250
return json.dumps(request_dict)
248251

@@ -263,7 +266,7 @@ def format_start_parameters(parameters: Dict[str, Any]) -> List[Dict[str, Any]]:
263266
return [{"Name": name, "Value": str(value)} for name, value in parameters.items()]
264267

265268

266-
def interpolate(request_obj: RequestType) -> RequestType:
269+
def interpolate(request_obj: RequestType, **kwargs) -> RequestType:
267270
"""Replaces Parameter values in a list of nested Dict[str, Any] with their workflow expression.
268271
269272
Args:
@@ -273,28 +276,62 @@ def interpolate(request_obj: RequestType) -> RequestType:
273276
RequestType: The request dict with Parameter values replaced by their expression.
274277
"""
275278
request_obj_copy = deepcopy(request_obj)
276-
return _interpolate(request_obj_copy)
279+
return _interpolate(
280+
request_obj_copy,
281+
callback_output_to_step_map=kwargs.get("callback_output_to_step_map", None),
282+
)
277283

278284

279-
def _interpolate(obj: Union[RequestType, Any]):
285+
def _interpolate(obj: Union[RequestType, Any], **kwargs):
    """Walks the nested request dict, replacing Parameter type values with workflow expressions.

    Args:
        obj (Union[RequestType, Any]): The request dict.
        **kwargs: May contain `callback_output_to_step_map`, a dict mapping a
            callback output name to a step name. It is used to resolve
            `CallbackOutput` values; `None` or absent is treated as empty.

    Returns:
        The value of `obj.expr` for `Expression`, `Parameter` and `Properties`
        values, the result of `obj.expr(step_name)` for `CallbackOutput`
        values, a new container of the same class for dicts, lists, sets and
        tuples, and `obj` itself for anything else.

    Raises:
        KeyError: If a `CallbackOutput` is encountered whose `output_name` is
            not a key of `callback_output_to_step_map`.
    """
    # `interpolate` always forwards this keyword, with the value None when the
    # caller supplied no map, so a plain `.get(key, {})` default never applies.
    callback_output_to_step_map = kwargs.get("callback_output_to_step_map") or {}

    if isinstance(obj, (Expression, Parameter, Properties)):
        return obj.expr
    if isinstance(obj, CallbackOutput):
        try:
            step_name = callback_output_to_step_map[obj.output_name]
        except KeyError:
            raise KeyError(
                f"No step name is mapped to callback output '{obj.output_name}'"
            ) from None
        return obj.expr(step_name)
    if isinstance(obj, dict):
        new = obj.__class__()
        for key, value in obj.items():
            # Recurse directly rather than through `interpolate`, which would
            # deep-copy every subtree again at each nesting level.
            new[key] = _interpolate(
                value, callback_output_to_step_map=callback_output_to_step_map
            )
        return new
    if isinstance(obj, (list, set, tuple)):
        return obj.__class__(
            _interpolate(value, callback_output_to_step_map=callback_output_to_step_map)
            for value in obj
        )
    return obj
296313

297314

315+
def _map_callback_outputs(steps: List[Step]):
    """Iterate over the provided steps, building a map of callback output parameters to step names.

    Args:
        steps (List[Step]): The steps list. Only `CallbackStep` instances with
            a non-empty `outputs` contribute entries.

    Returns:
        dict: Maps each output's `output_name` to the `name` of the
            `CallbackStep` that lists it. If several steps list the same
            output name, the step that comes last in `steps` wins.
    """
    callback_output_map = {}
    for step in steps:
        if isinstance(step, CallbackStep) and step.outputs:
            for output in step.outputs:
                callback_output_map[output.output_name] = step.name
    return callback_output_map
333+
334+
298335
def update_args(args: Dict[str, Any], **kwargs):
299336
"""Updates the request arguments dict with a value, if populated.
300337

tests/unit/sagemaker/workflow/test_callback_step.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ def test_pipeline_interpolates_callback_outputs():
8888
name="MyCallbackStep2",
8989
depends_on=["TestStep"],
9090
sqs_queue_url="https://sqs.us-east-2.amazonaws.com/123456789012/MyQueue",
91-
inputs={"arg1": cb_step1.properties.Outputs["output1"]},
91+
inputs={"arg1": outputParam1},
9292
outputs=[outputParam2],
9393
)
9494

0 commit comments

Comments
 (0)