12
12
# language governing permissions and limitations under the License.
13
13
from __future__ import absolute_import
14
14
15
- from sagemaker .workflow .parameters import ParameterInteger
15
+ import json
16
+
17
+ import pytest
18
+
19
+ from mock import Mock
20
+
21
+ from sagemaker .workflow .parameters import ParameterInteger , ParameterString
22
+ from sagemaker .workflow .pipeline import Pipeline
16
23
from sagemaker .workflow .callback_step import CallbackStep , CallbackOutput , CallbackOutputTypeEnum
17
24
25
+ from tests .unit .sagemaker .workflow .helpers import ordered
26
+
27
+ @pytest .fixture
28
+ def sagemaker_session_mock ():
29
+ return Mock ()
18
30
19
31
def test_callback_step ():
20
32
param = ParameterInteger (name = "MyInt" )
@@ -39,3 +51,91 @@ def test_callback_step():
39
51
],
40
52
"Arguments" : {"arg1" : "foo" , "arg2" : 5 , "arg3" : param },
41
53
}
54
+
55
+ def test_callback_step_output_expr ():
56
+ param = ParameterInteger (name = "MyInt" )
57
+ outputParam1 = CallbackOutput (output_name = "output1" , output_type = CallbackOutputTypeEnum .String )
58
+ outputParam2 = CallbackOutput (output_name = "output2" , output_type = CallbackOutputTypeEnum .Boolean )
59
+ cb_step = CallbackStep (
60
+ name = "MyCallbackStep" ,
61
+ depends_on = ["TestStep" ],
62
+ sqs_queue_url = "https://sqs.us-east-2.amazonaws.com/123456789012/MyQueue" ,
63
+ inputs = {"arg1" : "foo" , "arg2" : 5 , "arg3" : param },
64
+ outputs = [outputParam1 , outputParam2 ],
65
+ )
66
+
67
+ assert cb_step .properties .Outputs ['output1' ].expr == {"Get" : "Steps.MyCallbackStep.OutputParameters['output1']" }
68
+ assert cb_step .properties .Outputs ['output2' ].expr == {"Get" : "Steps.MyCallbackStep.OutputParameters['output2']" }
69
+
70
+ def test_pipeline_interpolates_callback_outputs ():
71
+ parameter = ParameterString ("MyStr" )
72
+ outputParam1 = CallbackOutput (output_name = "output1" , output_type = CallbackOutputTypeEnum .String )
73
+ outputParam2 = CallbackOutput (output_name = "output2" , output_type = CallbackOutputTypeEnum .String )
74
+ cb_step1 = CallbackStep (
75
+ name = "MyCallbackStep1" ,
76
+ depends_on = ["TestStep" ],
77
+ sqs_queue_url = "https://sqs.us-east-2.amazonaws.com/123456789012/MyQueue" ,
78
+ inputs = {"arg1" : "foo" },
79
+ outputs = [outputParam1 ],
80
+ )
81
+ cb_step2 = CallbackStep (
82
+ name = "MyCallbackStep2" ,
83
+ depends_on = ["TestStep" ],
84
+ sqs_queue_url = "https://sqs.us-east-2.amazonaws.com/123456789012/MyQueue" ,
85
+ inputs = {"arg1" : cb_step1 .properties .Outputs ['output1' ]},
86
+ outputs = [outputParam2 ],
87
+ )
88
+
89
+ pipeline = Pipeline (
90
+ name = "MyPipeline" ,
91
+ parameters = [parameter ],
92
+ steps = [cb_step1 , cb_step2 ],
93
+ sagemaker_session = sagemaker_session_mock ,
94
+ )
95
+
96
+ assert json .loads (pipeline .definition ()) == {
97
+ "Version" : "2020-12-01" ,
98
+ "Metadata" : {},
99
+ "Parameters" : [{"Name" : "MyStr" , "Type" : "String" }],
100
+ "PipelineExperimentConfig" : {
101
+ "ExperimentName" : {"Get" : "Execution.PipelineName" },
102
+ "TrialName" : {"Get" : "Execution.PipelineExecutionId" },
103
+ },
104
+ "Steps" : [
105
+ {
106
+ "Name" : "MyCallbackStep1" ,
107
+ "Type" : "Callback" ,
108
+ "Arguments" : {
109
+ "arg1" : "foo"
110
+ },
111
+ "DependsOn" : [
112
+ "TestStep"
113
+ ],
114
+ "SqsQueueUrl" : "https://sqs.us-east-2.amazonaws.com/123456789012/MyQueue" ,
115
+ "OutputParameters" : [
116
+ {
117
+ "OutputName" : "output1" ,
118
+ "OutputType" : "String"
119
+ }
120
+ ]
121
+ },
122
+ {
123
+ "Name" : "MyCallbackStep2" ,
124
+ "Type" : "Callback" ,
125
+ "Arguments" : {
126
+ "arg1" : { "Get" : "Steps.MyCallbackStep1.OutputParameters['output1']" }
127
+ },
128
+ "DependsOn" : [
129
+ "TestStep"
130
+ ],
131
+ "SqsQueueUrl" : "https://sqs.us-east-2.amazonaws.com/123456789012/MyQueue" ,
132
+ "OutputParameters" : [
133
+ {
134
+ "OutputName" : "output2" ,
135
+ "OutputType" : "String"
136
+ }
137
+ ]
138
+ }
139
+ ]
140
+ }
141
+
0 commit comments