Skip to content

Commit a05b10b

Browse files
staubhp, Payton Staub, and icywang86rui
authored
fix: Correctly interpolate Callback output parameters (aws#2467)
* Correctly interpolate Callback output parameters that are passed by reference
* Correctly interpolate Callback output parameters that are passed by reference
* Address PR comments
* Update api docs
* Add missing positional arg to integ test

Co-authored-by: Payton Staub <[email protected]>
Co-authored-by: icywang86rui <[email protected]>
1 parent 81afe34 commit a05b10b

File tree

4 files changed

+85
-16
lines changed

4 files changed

+85
-16
lines changed

src/sagemaker/workflow/callback_step.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -58,19 +58,20 @@ def to_request(self) -> RequestType:
5858
"OutputType": self.output_type.value,
5959
}
6060

61-
@property
62-
def expr(self) -> Dict[str, str]:
63-
"""The 'Get' expression dict for a `Parameter`."""
64-
return CallbackOutput._expr(self.output_name)
61+
def expr(self, step_name) -> Dict[str, str]:
62+
"""The 'Get' expression dict for a `CallbackOutput`."""
63+
return CallbackOutput._expr(self.output_name, step_name)
6564

6665
@classmethod
67-
def _expr(cls, name):
66+
def _expr(cls, name, step_name):
6867
"""An internal classmethod for the 'Get' expression dict for a `CallbackOutput`.
6968
7069
Args:
7170
name (str): The name of the callback output.
71+
step_name (str): The name of the callback step that this
72+
output belongs to.
7273
"""
73-
return {"Get": f"Steps.{name}.OutputParameters['{name}']"}
74+
return {"Get": f"Steps.{step_name}.OutputParameters['{name}']"}
7475

7576

7677
class CallbackStep(Step):

src/sagemaker/workflow/pipeline.py

Lines changed: 36 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424

2525
from sagemaker._studio import _append_project_tags
2626
from sagemaker.session import Session
27-
from sagemaker.workflow.callback_step import CallbackOutput
27+
from sagemaker.workflow.callback_step import CallbackOutput, CallbackStep
2828
from sagemaker.workflow.entities import (
2929
Entity,
3030
Expression,
@@ -240,9 +240,12 @@ def definition(self) -> str:
240240
"""Converts a request structure to string representation for workflow service calls."""
241241
request_dict = self.to_request()
242242
request_dict["PipelineExperimentConfig"] = interpolate(
243-
request_dict["PipelineExperimentConfig"]
243+
request_dict["PipelineExperimentConfig"], {}
244+
)
245+
callback_output_to_step_map = _map_callback_outputs(self.steps)
246+
request_dict["Steps"] = interpolate(
247+
request_dict["Steps"], callback_output_to_step_map=callback_output_to_step_map
244248
)
245-
request_dict["Steps"] = interpolate(request_dict["Steps"])
246249

247250
return json.dumps(request_dict)
248251

@@ -263,38 +266,62 @@ def format_start_parameters(parameters: Dict[str, Any]) -> List[Dict[str, Any]]:
263266
return [{"Name": name, "Value": str(value)} for name, value in parameters.items()]
264267

265268

266-
def interpolate(request_obj: RequestType) -> RequestType:
269+
def interpolate(
270+
request_obj: RequestType, callback_output_to_step_map: Dict[str, str]
271+
) -> RequestType:
267272
"""Replaces Parameter values in a list of nested Dict[str, Any] with their workflow expression.
268273
269274
Args:
270275
request_obj (RequestType): The request dict.
276+
callback_output_to_step_map (Dict[str, str]): A dict of output name -> step name.
271277
272278
Returns:
273279
RequestType: The request dict with Parameter values replaced by their expression.
274280
"""
275281
request_obj_copy = deepcopy(request_obj)
276-
return _interpolate(request_obj_copy)
282+
return _interpolate(request_obj_copy, callback_output_to_step_map=callback_output_to_step_map)
277283

278284

279-
def _interpolate(obj: Union[RequestType, Any]):
285+
def _interpolate(obj: Union[RequestType, Any], callback_output_to_step_map: Dict[str, str]):
280286
"""Walks the nested request dict, replacing Parameter type values with workflow expressions.
281287
282288
Args:
283289
obj (Union[RequestType, Any]): The request dict.
290+
callback_output_to_step_map (Dict[str, str]): A dict of output name -> step name.
284291
"""
285-
if isinstance(obj, (Expression, Parameter, Properties, CallbackOutput)):
292+
if isinstance(obj, (Expression, Parameter, Properties)):
286293
return obj.expr
294+
if isinstance(obj, CallbackOutput):
295+
step_name = callback_output_to_step_map[obj.output_name]
296+
return obj.expr(step_name)
287297
if isinstance(obj, dict):
288298
new = obj.__class__()
289299
for key, value in obj.items():
290-
new[key] = interpolate(value)
300+
new[key] = interpolate(value, callback_output_to_step_map)
291301
elif isinstance(obj, (list, set, tuple)):
292-
new = obj.__class__(interpolate(value) for value in obj)
302+
new = obj.__class__(interpolate(value, callback_output_to_step_map) for value in obj)
293303
else:
294304
return obj
295305
return new
296306

297307

308+
def _map_callback_outputs(steps: List[Step]):
309+
"""Iterate over the provided steps, building a map of callback output parameters to step names.
310+
311+
Args:
312+
steps (List[Step]): The steps list.
313+
"""
314+
315+
callback_output_map = {}
316+
for step in steps:
317+
if isinstance(step, CallbackStep):
318+
if step.outputs:
319+
for output in step.outputs:
320+
callback_output_map[output.output_name] = step.name
321+
322+
return callback_output_map
323+
324+
298325
def update_args(args: Dict[str, Any], **kwargs):
299326
"""Updates the request arguments dict with a value, if populated.
300327

tests/integ/test_workflow.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -739,6 +739,47 @@ def test_one_step_callback_pipeline(sagemaker_session, role, pipeline_name, regi
739739
pass
740740

741741

742+
def test_two_step_callback_pipeline_with_output_reference(
743+
sagemaker_session, role, pipeline_name, region_name
744+
):
745+
instance_count = ParameterInteger(name="InstanceCount", default_value=2)
746+
747+
outputParam1 = CallbackOutput(output_name="output1", output_type=CallbackOutputTypeEnum.String)
748+
step_callback1 = CallbackStep(
749+
name="callback-step1",
750+
sqs_queue_url="https://sqs.us-east-2.amazonaws.com/123456789012/MyQueue",
751+
inputs={"arg1": "foo"},
752+
outputs=[outputParam1],
753+
)
754+
755+
step_callback2 = CallbackStep(
756+
name="callback-step2",
757+
sqs_queue_url="https://sqs.us-east-2.amazonaws.com/123456789012/MyQueue",
758+
inputs={"arg1": outputParam1},
759+
outputs=[],
760+
)
761+
762+
pipeline = Pipeline(
763+
name=pipeline_name,
764+
parameters=[instance_count],
765+
steps=[step_callback1, step_callback2],
766+
sagemaker_session=sagemaker_session,
767+
)
768+
769+
try:
770+
response = pipeline.create(role)
771+
create_arn = response["PipelineArn"]
772+
assert re.match(
773+
fr"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}",
774+
create_arn,
775+
)
776+
finally:
777+
try:
778+
pipeline.delete()
779+
except Exception:
780+
pass
781+
782+
742783
def test_conditional_pytorch_training_model_registration(
743784
sagemaker_session,
744785
role,

tests/unit/sagemaker/workflow/test_callback_step.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ def test_pipeline_interpolates_callback_outputs():
8888
name="MyCallbackStep2",
8989
depends_on=["TestStep"],
9090
sqs_queue_url="https://sqs.us-east-2.amazonaws.com/123456789012/MyQueue",
91-
inputs={"arg1": cb_step1.properties.Outputs["output1"]},
91+
inputs={"arg1": outputParam1},
9292
outputs=[outputParam2],
9393
)
9494

0 commit comments

Comments
 (0)