Skip to content

Commit 3a9fa41

Browse files
author
Payton Staub
committed
Address PR comments
1 parent 6deec71 commit 3a9fa41

File tree

3 files changed

+51
-17
lines changed

3 files changed

+51
-17
lines changed

src/sagemaker/workflow/callback_step.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,8 @@ def _expr(cls, name, step_name):
6868
6969
Args:
7070
name (str): The name of the callback output.
71+
step_name (str): The name of the step the callback step associated
72+
with this output belongs to.
7173
"""
7274
return {"Get": f"Steps.{step_name}.OutputParameters['{name}']"}
7375

src/sagemaker/workflow/pipeline.py

Lines changed: 9 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -240,7 +240,8 @@ def definition(self) -> str:
240240
"""Converts a request structure to string representation for workflow service calls."""
241241
request_dict = self.to_request()
242242
request_dict["PipelineExperimentConfig"] = interpolate(
243-
request_dict["PipelineExperimentConfig"]
243+
request_dict["PipelineExperimentConfig"],
244+
{}
244245
)
245246
callback_output_to_step_map = _map_callback_outputs(self.steps)
246247
request_dict["Steps"] = interpolate(
@@ -266,7 +267,9 @@ def format_start_parameters(parameters: Dict[str, Any]) -> List[Dict[str, Any]]:
266267
return [{"Name": name, "Value": str(value)} for name, value in parameters.items()]
267268

268269

269-
def interpolate(request_obj: RequestType, **kwargs) -> RequestType:
270+
def interpolate(
271+
request_obj: RequestType, callback_output_to_step_map: Dict[str, str]
272+
) -> RequestType:
270273
"""Replaces Parameter values in a list of nested Dict[str, Any] with their workflow expression.
271274
272275
Args:
@@ -276,13 +279,10 @@ def interpolate(request_obj: RequestType, **kwargs) -> RequestType:
276279
RequestType: The request dict with Parameter values replaced by their expression.
277280
"""
278281
request_obj_copy = deepcopy(request_obj)
279-
return _interpolate(
280-
request_obj_copy,
281-
callback_output_to_step_map=kwargs.get("callback_output_to_step_map", None),
282-
)
282+
return _interpolate(request_obj_copy, callback_output_to_step_map=callback_output_to_step_map)
283283

284284

285-
def _interpolate(obj: Union[RequestType, Any], **kwargs):
285+
def _interpolate(obj: Union[RequestType, Any], callback_output_to_step_map: Dict[str, str]):
286286
"""Walks the nested request dict, replacing Parameter type values with workflow expressions.
287287
288288
Args:
@@ -291,22 +291,14 @@ def _interpolate(obj: Union[RequestType, Any], **kwargs):
291291
if isinstance(obj, (Expression, Parameter, Properties)):
292292
return obj.expr
293293
if isinstance(obj, CallbackOutput):
294-
callback_output_to_step_map = kwargs.get("callback_output_to_step_map", {})
295294
step_name = callback_output_to_step_map[obj.output_name]
296295
return obj.expr(step_name)
297296
if isinstance(obj, dict):
298297
new = obj.__class__()
299298
for key, value in obj.items():
300-
new[key] = interpolate(
301-
value, callback_output_to_step_map=kwargs.get("callback_output_to_step_map", None)
302-
)
299+
new[key] = interpolate(value, callback_output_to_step_map)
303300
elif isinstance(obj, (list, set, tuple)):
304-
new = obj.__class__(
305-
interpolate(
306-
value, callback_output_to_step_map=kwargs.get("callback_output_to_step_map", None)
307-
)
308-
for value in obj
309-
)
301+
new = obj.__class__(interpolate(value, callback_output_to_step_map) for value in obj)
310302
else:
311303
return obj
312304
return new

tests/integ/test_workflow.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -739,6 +739,46 @@ def test_one_step_callback_pipeline(sagemaker_session, role, pipeline_name, regi
739739
pass
740740

741741

742+
def test_two_step_callback_pipeline_with_output_reference(
743+
sagemaker_session, role, pipeline_name, region_name
744+
):
745+
instance_count = ParameterInteger(name="InstanceCount", default_value=2)
746+
747+
outputParam1 = CallbackOutput(output_name="output1", output_type=CallbackOutputTypeEnum.String)
748+
step_callback1 = CallbackStep(
749+
name="callback-step1",
750+
sqs_queue_url="https://sqs.us-east-2.amazonaws.com/123456789012/MyQueue",
751+
inputs={"arg1": "foo"},
752+
outputs=[outputParam1],
753+
)
754+
755+
step_callback2 = CallbackStep(
756+
name="callback-step2",
757+
sqs_queue_url="https://sqs.us-east-2.amazonaws.com/123456789012/MyQueue",
758+
inputs={"arg1": outputParam1},
759+
)
760+
761+
pipeline = Pipeline(
762+
name=pipeline_name,
763+
parameters=[instance_count],
764+
steps=[step_callback1, step_callback2],
765+
sagemaker_session=sagemaker_session,
766+
)
767+
768+
try:
769+
response = pipeline.create(role)
770+
create_arn = response["PipelineArn"]
771+
assert re.match(
772+
fr"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}",
773+
create_arn,
774+
)
775+
finally:
776+
try:
777+
pipeline.delete()
778+
except Exception:
779+
pass
780+
781+
742782
def test_conditional_pytorch_training_model_registration(
743783
sagemaker_session,
744784
role,

0 commit comments

Comments
 (0)