14
14
15
15
import unittest .mock
16
16
17
- import pytest
18
17
from sagemaker .lineage import visualizer
19
18
import pandas as pd
20
19
from collections import OrderedDict
21
20
22
21
23
- @pytest .fixture
24
- def sagemaker_session ():
25
- return unittest .mock .Mock ()
26
-
27
-
28
- @pytest .fixture
29
- def vizualizer (sagemaker_session ):
30
- return visualizer .LineageTableVisualizer (sagemaker_session )
31
-
32
-
33
def test_friendly_name_short_uri(viz, sagemaker_session):
    """Check that a short artifact source URI is used verbatim as the friendly name.

    ``describe_artifact`` on the session's client is stubbed to report the short
    S3 URI as the artifact's ``SourceUri``; the name resolved for the artifact
    ARN must equal that URI exactly.

    NOTE(review): ``viz`` and ``sagemaker_session`` are fixtures defined outside
    this view -- presumably ``viz`` wraps the same mocked session; confirm.
    """
    short_uri = "s3://f-069083975568/train.txt"
    artifact_description = {"Source": {"SourceUri": short_uri, "SourceTypes": ""}}

    client = sagemaker_session.sagemaker_client
    client.describe_artifact.return_value = artifact_description

    # No explicit name is supplied, so the name has to be derived from the ARN lookup.
    friendly_name = viz._get_friendly_name(name=None, arn="test_arn", entity_type="artifact")

    assert short_uri == friendly_name
41
30
42
31
43
- def test_friendly_name_long_uri (vizualizer , sagemaker_session ):
32
+ def test_friendly_name_long_uri (viz , sagemaker_session ):
44
33
uri = (
45
34
"s3://flintstone-end-to-end-tests-gamma-us-west-2-069083975568/results/canary-auto-1608761252626/"
46
35
"preprocessed-data/tuning_data/train.txt"
@@ -49,12 +38,12 @@ def test_friendly_name_long_uri(vizualizer, sagemaker_session):
49
38
sagemaker_session .sagemaker_client .describe_artifact .return_value = {
50
39
"Source" : {"SourceUri" : uri , "SourceTypes" : "" }
51
40
}
52
- actual_name = vizualizer ._get_friendly_name (name = None , arn = arn , entity_type = "artifact" )
41
+ actual_name = viz ._get_friendly_name (name = None , arn = arn , entity_type = "artifact" )
53
42
expected_name = "s3://.../preprocessed-data/tuning_data/train.txt"
54
43
assert expected_name == actual_name
55
44
56
45
57
- def test_trial_component_name (sagemaker_session , vizualizer ):
46
+ def test_trial_component_name (viz , sagemaker_session ):
58
47
name = "tc-name"
59
48
60
49
sagemaker_session .sagemaker_client .describe_trial_component .return_value = {
@@ -90,7 +79,7 @@ def test_trial_component_name(sagemaker_session, vizualizer):
90
79
},
91
80
]
92
81
93
- df = vizualizer .show (trial_component_name = name )
82
+ df = viz .show (trial_component_name = name )
94
83
95
84
sagemaker_session .sagemaker_client .describe_trial_component .assert_called_with (
96
85
TrialComponentName = name ,
@@ -119,3 +108,71 @@ def test_trial_component_name(sagemaker_session, vizualizer):
119
108
)
120
109
121
110
pd .testing .assert_frame_equal (expected_dataframe , df )
111
+
112
+
113
def test_processing_job_pipeline_execution_step(viz, sagemaker_session):
    """Check ``show`` for a pipeline execution step that carries a processing job.

    The step's processing-job ARN is expected to be looked up via
    ``list_trial_components``; the returned trial component ARN is then expected
    to be queried through ``list_associations`` twice -- first as the
    destination, then as the source. The resulting frame must hold one "Input"
    row built from the first response's source fields and one "Output" row
    built from the second response's destination fields.

    NOTE(review): ``viz`` and ``sagemaker_session`` are fixtures defined outside
    this view -- presumably ``viz`` wraps the same mocked session; confirm.
    """
    client = sagemaker_session.sagemaker_client

    client.list_trial_components.return_value = {
        "TrialComponentSummaries": [{"TrialComponentArn": "tc-arn"}]
    }

    # Response handed out for the first list_associations call.
    first_response = {
        "AssociationSummaries": [
            {
                "SourceArn": "a:b:c:d:e:artifact/src-arn-1",
                "SourceName": "source-name-1",
                "SourceType": "source-type-1",
                "DestinationArn": "a:b:c:d:e:artifact/dest-arn-1",
                "DestinationName": "dest-name-1",
                "DestinationType": "dest-type-1",
                "AssociationType": "type-1",
            }
        ]
    }
    # Response handed out for the second list_associations call.
    second_response = {
        "AssociationSummaries": [
            {
                "SourceArn": "a:b:c:d:e:artifact/src-arn-2",
                "SourceName": "source-name-2",
                "SourceType": "source-type-2",
                "DestinationArn": "a:b:c:d:e:artifact/dest-arn-2",
                "DestinationName": "dest-name-2",
                "DestinationType": "dest-type-2",
                "AssociationType": "type-2",
            }
        ]
    }
    # A list side_effect yields one response per successive call, in order.
    client.list_associations.side_effect = [first_response, second_response]

    execution_step = {"Metadata": {"ProcessingJob": {"Arn": "proc-job-arn"}}}
    result = viz.show(pipeline_execution_step=execution_step)

    client.list_trial_components.assert_called_with(SourceArn="proc-job-arn")

    assert [
        unittest.mock.call(DestinationArn="tc-arn"),
        unittest.mock.call(SourceArn="tc-arn"),
    ] == client.list_associations.mock_calls

    # Column order matters for assert_frame_equal, hence the ordered mapping.
    columns = OrderedDict()
    columns["Name/Source"] = ["source-name-1", "dest-name-2"]
    columns["Direction"] = ["Input", "Output"]
    columns["Type"] = ["source-type-1", "dest-type-2"]
    columns["Association Type"] = ["type-1", "type-2"]
    columns["Lineage Type"] = ["artifact", "artifact"]

    pd.testing.assert_frame_equal(pd.DataFrame.from_dict(columns), result)
0 commit comments