@@ -49,64 +49,17 @@ def test_trial_component_name(viz, sagemaker_session):
49
49
"TrialComponentArn": "tc-arn",
50
50
}
51
51
52
- sagemaker_session.sagemaker_client.list_associations.side_effect = [
53
- {
54
- "AssociationSummaries": [
55
- {
56
- "SourceArn": "a:b:c:d:e:artifact/src-arn-1",
57
- "SourceName": "source-name-1",
58
- "SourceType": "source-type-1",
59
- "DestinationArn": "a:b:c:d:e:artifact/dest-arn-1",
60
- "DestinationName": "dest-name-1",
61
- "DestinationType": "dest-type-1",
62
- "AssociationType": "type-1",
63
- }
64
- ]
65
- },
66
- {
67
- "AssociationSummaries": [
68
- {
69
- "SourceArn": "a:b:c:d:e:artifact/src-arn-2",
70
- "SourceName": "source-name-2",
71
- "SourceType": "source-type-2",
72
- "DestinationArn": "a:b:c:d:e:artifact/dest-arn-2",
73
- "DestinationName": "dest-name-2",
74
- "DestinationType": "dest-type-2",
75
- "AssociationType": "type-2",
76
- }
77
- ]
78
- },
79
- ]
52
+ get_list_associations_side_effect(sagemaker_session)
80
53
81
54
df = viz.show(trial_component_name=name)
82
55
83
56
sagemaker_session.sagemaker_client.describe_trial_component.assert_called_with(
84
57
TrialComponentName=name,
85
58
)
86
59
87
- expected_calls = [
88
- unittest.mock.call(
89
- DestinationArn="tc-arn",
90
- ),
91
- unittest.mock.call(
92
- SourceArn="tc-arn",
93
- ),
94
- ]
95
- assert expected_calls == sagemaker_session.sagemaker_client.list_associations.mock_calls
60
+ assert_list_associations_mock_calls(sagemaker_session)
96
61
97
- expected_dataframe = pd.DataFrame.from_dict(
98
- OrderedDict(
99
- [
100
- ("Name/Source", ["source-name-1", "dest-name-2"]),
101
- ("Direction", ["Input", "Output"]),
102
- ("Type", ["source-type-1", "dest-type-2"]),
103
- ("Association Type", ["type-1", "type-2"]),
104
- ("Lineage Type", ["artifact", "artifact"]),
105
- ]
106
- )
107
- )
108
-
109
- pd.testing.assert_frame_equal(expected_dataframe, df)
62
+ pd.testing.assert_frame_equal(get_expected_dataframe(), df)
110
63
111
64
112
65
def test_model_package_arn(viz, sagemaker_session):
@@ -116,34 +69,7 @@ def test_model_package_arn(viz, sagemaker_session):
116
69
"ArtifactSummaries": [{"ArtifactArn": "artifact-arn"}]
117
70
}
118
71
119
- sagemaker_session.sagemaker_client.list_associations.side_effect = [
120
- {
121
- "AssociationSummaries": [
122
- {
123
- "SourceArn": "a:b:c:d:e:artifact/src-arn-1",
124
- "SourceName": "source-name-1",
125
- "SourceType": "source-type-1",
126
- "DestinationArn": "a:b:c:d:e:artifact/dest-arn-1",
127
- "DestinationName": "dest-name-1",
128
- "DestinationType": "dest-type-1",
129
- "AssociationType": "type-1",
130
- }
131
- ]
132
- },
133
- {
134
- "AssociationSummaries": [
135
- {
136
- "SourceArn": "a:b:c:d:e:artifact/src-arn-2",
137
- "SourceName": "source-name-2",
138
- "SourceType": "source-type-2",
139
- "DestinationArn": "a:b:c:d:e:artifact/dest-arn-2",
140
- "DestinationName": "dest-name-2",
141
- "DestinationType": "dest-type-2",
142
- "AssociationType": "type-2",
143
- }
144
- ]
145
- },
146
- ]
72
+ get_list_associations_side_effect(sagemaker_session)
147
73
148
74
df = viz.show(model_package_arn=name)
149
75
@@ -161,19 +87,7 @@ def test_model_package_arn(viz, sagemaker_session):
161
87
]
162
88
assert expected_calls == sagemaker_session.sagemaker_client.list_associations.mock_calls
163
89
164
- expected_dataframe = pd.DataFrame.from_dict(
165
- OrderedDict(
166
- [
167
- ("Name/Source", ["source-name-1", "dest-name-2"]),
168
- ("Direction", ["Input", "Output"]),
169
- ("Type", ["source-type-1", "dest-type-2"]),
170
- ("Association Type", ["type-1", "type-2"]),
171
- ("Lineage Type", ["artifact", "artifact"]),
172
- ]
173
- )
174
- )
175
-
176
- pd.testing.assert_frame_equal(expected_dataframe, df)
90
+ pd.testing.assert_frame_equal(get_expected_dataframe(), df)
177
91
178
92
179
93
def test_endpoint_arn(viz, sagemaker_session):
@@ -183,34 +97,7 @@ def test_endpoint_arn(viz, sagemaker_session):
183
97
"ContextSummaries": [{"ContextArn": "context-arn"}]
184
98
}
185
99
186
- sagemaker_session.sagemaker_client.list_associations.side_effect = [
187
- {
188
- "AssociationSummaries": [
189
- {
190
- "SourceArn": "a:b:c:d:e:context/src-arn-1",
191
- "SourceName": "source-name-1",
192
- "SourceType": "source-type-1",
193
- "DestinationArn": "a:b:c:d:e:context/dest-arn-1",
194
- "DestinationName": "dest-name-1",
195
- "DestinationType": "dest-type-1",
196
- "AssociationType": "type-1",
197
- }
198
- ]
199
- },
200
- {
201
- "AssociationSummaries": [
202
- {
203
- "SourceArn": "a:b:c:d:e:context/src-arn-2",
204
- "SourceName": "source-name-2",
205
- "SourceType": "source-type-2",
206
- "DestinationArn": "a:b:c:d:e:context/dest-arn-2",
207
- "DestinationName": "dest-name-2",
208
- "DestinationType": "dest-type-2",
209
- "AssociationType": "type-2",
210
- }
211
- ]
212
- },
213
- ]
100
+ get_list_associations_side_effect(sagemaker_session)
214
101
215
102
df = viz.show(endpoint_arn=name)
216
103
@@ -228,27 +115,74 @@ def test_endpoint_arn(viz, sagemaker_session):
228
115
]
229
116
assert expected_calls == sagemaker_session.sagemaker_client.list_associations.mock_calls
230
117
231
- expected_dataframe = pd.DataFrame.from_dict(
232
- OrderedDict(
233
- [
234
- ("Name/Source", ["source-name-1", "dest-name-2"]),
235
- ("Direction", ["Input", "Output"]),
236
- ("Type", ["source-type-1", "dest-type-2"]),
237
- ("Association Type", ["type-1", "type-2"]),
238
- ("Lineage Type", ["context", "context"]),
239
- ]
240
- )
118
+ pd.testing.assert_frame_equal(get_expected_dataframe(), df)
119
+
120
+
121
+ def test_processing_job_pipeline_execution_step(viz, sagemaker_session):
122
+
123
+ sagemaker_session.sagemaker_client.list_trial_components.return_value = {
124
+ "TrialComponentSummaries": [{"TrialComponentArn": "tc-arn"}]
125
+ }
126
+
127
+ get_list_associations_side_effect(sagemaker_session)
128
+
129
+ step = {"Metadata": {"ProcessingJob": {"Arn": "proc-job-arn"}}}
130
+
131
+ df = viz.show(pipeline_execution_step=step)
132
+
133
+ sagemaker_session.sagemaker_client.list_trial_components.assert_called_with(
134
+ SourceArn="proc-job-arn",
241
135
)
242
136
243
- pd.testing.assert_frame_equal(expected_dataframe, df )
137
+ assert_list_associations_mock_calls(sagemaker_session )
244
138
139
+ pd.testing.assert_frame_equal(get_expected_dataframe(), df)
245
140
246
- def test_processing_job_pipeline_execution_step(viz, sagemaker_session):
141
+
142
+ def test_training_job_pipeline_execution_step(viz, sagemaker_session):
247
143
248
144
sagemaker_session.sagemaker_client.list_trial_components.return_value = {
249
145
"TrialComponentSummaries": [{"TrialComponentArn": "tc-arn"}]
250
146
}
251
147
148
+ get_list_associations_side_effect(sagemaker_session)
149
+
150
+ step = {"Metadata": {"TrainingJob": {"Arn": "training-job-arn"}}}
151
+
152
+ df = viz.show(pipeline_execution_step=step)
153
+
154
+ sagemaker_session.sagemaker_client.list_trial_components.assert_called_with(
155
+ SourceArn="training-job-arn",
156
+ )
157
+
158
+ assert_list_associations_mock_calls(sagemaker_session)
159
+
160
+ pd.testing.assert_frame_equal(get_expected_dataframe(), df)
161
+
162
+
163
+ def test_transform_job_pipeline_execution_step(viz, sagemaker_session):
164
+
165
+ sagemaker_session.sagemaker_client.list_trial_components.return_value = {
166
+ "TrialComponentSummaries": [{"TrialComponentArn": "tc-arn"}]
167
+ }
168
+
169
+ get_list_associations_side_effect(sagemaker_session)
170
+
171
+ step = {"Metadata": {"TransformJob": {"Arn": "transform-job-arn"}}}
172
+
173
+ df = viz.show(pipeline_execution_step=step)
174
+
175
+ sagemaker_session.sagemaker_client.list_trial_components.assert_called_with(
176
+ SourceArn="transform-job-arn",
177
+ )
178
+
179
+ assert_list_associations_mock_calls(sagemaker_session)
180
+
181
+ pd.testing.assert_frame_equal(get_expected_dataframe(), df)
182
+
183
+
184
+ def get_list_associations_side_effect(sagemaker_session):
185
+
252
186
sagemaker_session.sagemaker_client.list_associations.side_effect = [
253
187
{
254
188
"AssociationSummaries": [
@@ -278,13 +212,8 @@ def test_processing_job_pipeline_execution_step(viz, sagemaker_session):
278
212
},
279
213
]
280
214
281
- step = {"Metadata": {"ProcessingJob": {"Arn": "proc-job-arn"}}}
282
-
283
- df = viz.show(pipeline_execution_step=step)
284
215
285
- sagemaker_session.sagemaker_client.list_trial_components.assert_called_with(
286
- SourceArn="proc-job-arn",
287
- )
216
+ def assert_list_associations_mock_calls(sagemaker_session):
288
217
289
218
expected_calls = [
290
219
unittest.mock.call(
@@ -296,6 +225,9 @@ def test_processing_job_pipeline_execution_step(viz, sagemaker_session):
296
225
]
297
226
assert expected_calls == sagemaker_session.sagemaker_client.list_associations.mock_calls
298
227
228
+
229
+ def get_expected_dataframe():
230
+
299
231
expected_dataframe = pd.DataFrame.from_dict(
300
232
OrderedDict(
301
233
[
@@ -308,4 +240,4 @@ def test_processing_job_pipeline_execution_step(viz, sagemaker_session):
308
240
)
309
241
)
310
242
311
- pd.testing.assert_frame_equal(expected_dataframe, df)
243
+ return expected_dataframe
0 commit comments