Skip to content

Commit 14b743a

Browse files
author
Payton Staub
committed
Fix multiple output parameters in callback step properties. Additional unit tests
1 parent 8c50d92 commit 14b743a

File tree

2 files changed

+105
-3
lines changed

2 files changed

+105
-3
lines changed

src/sagemaker/workflow/callback_step.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -105,12 +105,14 @@ def __init__(
105105

106106
root_path = f"Steps.{name}"
107107
root_prop = Properties(path=root_path)
108+
109+
property_dict = {}
108110
for output in outputs:
109-
property_dict = {}
110111
property_dict[output.output_name] = Properties(
111112
f"{root_path}.OutputParameters['{output.output_name}']"
112113
)
113-
root_prop.__dict__["Outputs"] = property_dict
114+
115+
root_prop.__dict__["Outputs"] = property_dict
114116
self._properties = root_prop
115117

116118
@property

tests/unit/sagemaker/workflow/test_callback_step.py

Lines changed: 101 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,21 @@
1212
# language governing permissions and limitations under the License.
1313
from __future__ import absolute_import
1414

15-
from sagemaker.workflow.parameters import ParameterInteger
15+
import json
16+
17+
import pytest
18+
19+
from mock import Mock
20+
21+
from sagemaker.workflow.parameters import ParameterInteger, ParameterString
22+
from sagemaker.workflow.pipeline import Pipeline
1623
from sagemaker.workflow.callback_step import CallbackStep, CallbackOutput, CallbackOutputTypeEnum
1724

25+
from tests.unit.sagemaker.workflow.helpers import ordered
26+
27+
@pytest.fixture
28+
def sagemaker_session_mock():
29+
return Mock()
1830

1931
def test_callback_step():
2032
param = ParameterInteger(name="MyInt")
@@ -39,3 +51,91 @@ def test_callback_step():
3951
],
4052
"Arguments": {"arg1": "foo", "arg2": 5, "arg3": param},
4153
}
54+
55+
def test_callback_step_output_expr():
56+
param = ParameterInteger(name="MyInt")
57+
outputParam1 = CallbackOutput(output_name="output1", output_type=CallbackOutputTypeEnum.String)
58+
outputParam2 = CallbackOutput(output_name="output2", output_type=CallbackOutputTypeEnum.Boolean)
59+
cb_step = CallbackStep(
60+
name="MyCallbackStep",
61+
depends_on=["TestStep"],
62+
sqs_queue_url="https://sqs.us-east-2.amazonaws.com/123456789012/MyQueue",
63+
inputs={"arg1": "foo", "arg2": 5, "arg3": param},
64+
outputs=[outputParam1, outputParam2],
65+
)
66+
67+
assert cb_step.properties.Outputs['output1'].expr == {"Get": "Steps.MyCallbackStep.OutputParameters['output1']"}
68+
assert cb_step.properties.Outputs['output2'].expr == {"Get": "Steps.MyCallbackStep.OutputParameters['output2']"}
69+
70+
def test_pipeline_interpolates_callback_outputs():
71+
parameter = ParameterString("MyStr")
72+
outputParam1 = CallbackOutput(output_name="output1", output_type=CallbackOutputTypeEnum.String)
73+
outputParam2 = CallbackOutput(output_name="output2", output_type=CallbackOutputTypeEnum.String)
74+
cb_step1 = CallbackStep(
75+
name="MyCallbackStep1",
76+
depends_on=["TestStep"],
77+
sqs_queue_url="https://sqs.us-east-2.amazonaws.com/123456789012/MyQueue",
78+
inputs={"arg1": "foo"},
79+
outputs=[outputParam1],
80+
)
81+
cb_step2 = CallbackStep(
82+
name="MyCallbackStep2",
83+
depends_on=["TestStep"],
84+
sqs_queue_url="https://sqs.us-east-2.amazonaws.com/123456789012/MyQueue",
85+
inputs={"arg1": cb_step1.properties.Outputs['output1']},
86+
outputs=[outputParam2],
87+
)
88+
89+
pipeline = Pipeline(
90+
name="MyPipeline",
91+
parameters=[parameter],
92+
steps=[cb_step1, cb_step2],
93+
sagemaker_session=sagemaker_session_mock,
94+
)
95+
96+
assert json.loads(pipeline.definition()) == {
97+
"Version": "2020-12-01",
98+
"Metadata": {},
99+
"Parameters": [{"Name": "MyStr", "Type": "String"}],
100+
"PipelineExperimentConfig": {
101+
"ExperimentName": {"Get": "Execution.PipelineName"},
102+
"TrialName": {"Get": "Execution.PipelineExecutionId"},
103+
},
104+
"Steps": [
105+
{
106+
"Name": "MyCallbackStep1",
107+
"Type": "Callback",
108+
"Arguments": {
109+
"arg1": "foo"
110+
},
111+
"DependsOn": [
112+
"TestStep"
113+
],
114+
"SqsQueueUrl": "https://sqs.us-east-2.amazonaws.com/123456789012/MyQueue",
115+
"OutputParameters": [
116+
{
117+
"OutputName": "output1",
118+
"OutputType": "String"
119+
}
120+
]
121+
},
122+
{
123+
"Name": "MyCallbackStep2",
124+
"Type": "Callback",
125+
"Arguments": {
126+
"arg1": { "Get": "Steps.MyCallbackStep1.OutputParameters['output1']"}
127+
},
128+
"DependsOn": [
129+
"TestStep"
130+
],
131+
"SqsQueueUrl": "https://sqs.us-east-2.amazonaws.com/123456789012/MyQueue",
132+
"OutputParameters": [
133+
{
134+
"OutputName": "output2",
135+
"OutputType": "String"
136+
}
137+
]
138+
}
139+
]
140+
}
141+

0 commit comments

Comments
 (0)