Skip to content

Commit 8c50d92

Browse files
author
Payton Staub
committed
Fix referencing of CallbackOutputs in other steps
1 parent 03bbd28 commit 8c50d92

File tree

2 files changed

+23
-3
lines changed

2 files changed

+23
-3
lines changed

src/sagemaker/workflow/callback_step.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
"""The step definitions for workflow."""
1414
from __future__ import absolute_import
1515

16-
from typing import List
16+
from typing import List, Dict
1717
from enum import Enum
1818

1919
import attr
@@ -59,6 +59,20 @@ def to_request(self) -> RequestType:
5959
"OutputType": self.output_type.value,
6060
}
6161

62+
@property
63+
def expr(self) -> Dict[str, str]:
64+
"""The 'Get' expression dict for a `Parameter`."""
65+
return CallbackOutput._expr(self.output_name)
66+
67+
@classmethod
68+
def _expr(cls, name):
69+
"""An internal classmethod for the 'Get' expression dict for a `CallbackOutput`.
70+
71+
Args:
72+
name (str): The name of the callback output.
73+
"""
74+
return {"Get": f"Steps.{name}.OutputParameters['{name}']"}
75+
6276

6377
class CallbackStep(Step):
6478
"""Callback step for workflow."""
@@ -91,7 +105,12 @@ def __init__(
91105

92106
root_path = f"Steps.{name}"
93107
root_prop = Properties(path=root_path)
94-
root_prop.__dict__["OutputParameters"] = Properties(f"{root_path}.OutputParameters")
108+
for output in outputs:
109+
property_dict = {}
110+
property_dict[output.output_name] = Properties(
111+
f"{root_path}.OutputParameters['{output.output_name}']"
112+
)
113+
root_prop.__dict__["Outputs"] = property_dict
95114
self._properties = root_prop
96115

97116
@property

src/sagemaker/workflow/pipeline.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424

2525
from sagemaker._studio import _append_project_tags
2626
from sagemaker.session import Session
27+
from sagemaker.workflow.callback_step import CallbackOutput
2728
from sagemaker.workflow.entities import (
2829
Entity,
2930
Expression,
@@ -281,7 +282,7 @@ def _interpolate(obj: Union[RequestType, Any]):
281282
Args:
282283
obj (Union[RequestType, Any]): The request dict.
283284
"""
284-
if isinstance(obj, (Expression, Parameter, Properties)):
285+
if isinstance(obj, (Expression, Parameter, Properties, CallbackOutput)):
285286
return obj.expr
286287
if isinstance(obj, dict):
287288
new = obj.__class__()

0 commit comments

Comments
 (0)