Skip to content

Commit 73f2d4f

Browse files
author
Yi-Ting Lee
committed
resolve integ test conflict
2 parents 5609e42 + 7cec38d commit 73f2d4f

File tree

4 files changed

+366
-1
lines changed

4 files changed

+366
-1
lines changed
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
urllib3==1.26.8
22
docker-compose==1.29.2
33
docker~=5.0.0
4-
PyYAML==5.4.1
4+
PyYAML==5.4.1

requirements/extras/test_requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,3 +18,4 @@ fabric==2.6.0
1818
requests==2.27.1
1919
sagemaker-experiments==0.1.35
2020
Jinja2==3.0.3
21+
pyvis==0.2.1

tests/integ/sagemaker/lineage/helpers.py

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,3 +80,93 @@ def visit(arn, visited: set):
8080
ret = []
8181
return visit(start_arn, set())
8282

83+
84+
class LineageResourceHelper:
85+
def __init__(self, sagemaker_session):
86+
self.client = sagemaker_session.sagemaker_client
87+
self.artifacts = []
88+
self.actions = []
89+
self.contexts = []
90+
self.associations = []
91+
92+
def create_artifact(self, artifact_name, artifact_type="Dataset"):
93+
response = self.client.create_artifact(
94+
ArtifactName=artifact_name,
95+
Source={
96+
"SourceUri": "Test-artifact-" + artifact_name,
97+
"SourceTypes": [
98+
{"SourceIdType": "S3ETag", "Value": "Test-artifact-sourceId-value"},
99+
],
100+
},
101+
ArtifactType=artifact_type,
102+
)
103+
self.artifacts.append(response["ArtifactArn"])
104+
105+
return response["ArtifactArn"]
106+
107+
def create_action(self, action_name, action_type="ModelDeployment"):
108+
response = self.client.create_action(
109+
ActionName=action_name,
110+
Source={
111+
"SourceUri": "Test-action-" + action_name,
112+
"SourceType": "S3ETag",
113+
"SourceId": "Test-action-sourceId-value",
114+
},
115+
ActionType=action_type,
116+
)
117+
self.actions.append(response["ActionArn"])
118+
119+
return response["ActionArn"]
120+
121+
def create_context(self, context_name, context_type="Endpoint"):
122+
response = self.client.create_context(
123+
ContextName=context_name,
124+
Source={
125+
"SourceUri": "Test-context-" + context_name,
126+
"SourceType": "S3ETag",
127+
"SourceId": "Test-context-sourceId-value",
128+
},
129+
ContextType=context_type,
130+
)
131+
self.contexts.append(response["ContextArn"])
132+
133+
return response["ContextArn"]
134+
135+
def create_association(self, source_arn, dest_arn, association_type="AssociatedWith"):
136+
response = self.client.add_association(
137+
SourceArn=source_arn, DestinationArn=dest_arn, AssociationType=association_type
138+
)
139+
if "SourceArn" in response.keys():
140+
self.associations.append((source_arn, dest_arn))
141+
return True
142+
else:
143+
return False
144+
145+
def clean_all(self):
146+
for source, dest in self.associations:
147+
try:
148+
self.client.delete_association(SourceArn=source, DestinationArn=dest)
149+
time.sleep(0.5)
150+
except Exception as e:
151+
print("skipped " + str(e))
152+
153+
for artifact_arn in self.artifacts:
154+
try:
155+
self.client.delete_artifact(ArtifactArn=artifact_arn)
156+
time.sleep(0.5)
157+
except Exception as e:
158+
print("skipped " + str(e))
159+
160+
for action_arn in self.actions:
161+
try:
162+
self.client.delete_action(ActionArn=action_arn)
163+
time.sleep(0.5)
164+
except Exception as e:
165+
print("skipped " + str(e))
166+
167+
for context_arn in self.contexts:
168+
try:
169+
self.client.delete_context(ContextArn=context_arn)
170+
time.sleep(0.5)
171+
except Exception as e:
172+
print("skipped " + str(e))
Lines changed: 274 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,274 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""This module contains code to test SageMaker ``LineageQueryResult.visualize()``"""
14+
from __future__ import absolute_import
15+
import time
16+
import json
17+
import os
18+
19+
import pytest
20+
21+
import sagemaker.lineage.query
22+
from sagemaker.lineage.query import LineageQueryDirectionEnum
23+
from tests.integ.sagemaker.lineage.helpers import name, LineageResourceHelper
24+
25+
26+
def test_LineageResourceHelper(sagemaker_session):
27+
# check if LineageResourceHelper works properly
28+
lineage_resource_helper = LineageResourceHelper(sagemaker_session=sagemaker_session)
29+
try:
30+
art1 = lineage_resource_helper.create_artifact(artifact_name=name())
31+
art2 = lineage_resource_helper.create_artifact(artifact_name=name())
32+
lineage_resource_helper.create_association(source_arn=art1, dest_arn=art2)
33+
lineage_resource_helper.clean_all()
34+
except Exception as e:
35+
print(e)
36+
assert False
37+
38+
39+
@pytest.mark.skip("visualizer load test")
40+
def test_wide_graph_visualize(sagemaker_session):
41+
lineage_resource_helper = LineageResourceHelper(sagemaker_session=sagemaker_session)
42+
wide_graph_root_arn = lineage_resource_helper.create_artifact(artifact_name=name())
43+
44+
# create wide graph
45+
# Artifact ----> Artifact
46+
# \ \ \-> Artifact
47+
# \ \--> Artifact
48+
# \---> ...
49+
try:
50+
for i in range(10):
51+
artifact_arn = lineage_resource_helper.create_artifact(artifact_name=name())
52+
lineage_resource_helper.create_association(
53+
source_arn=wide_graph_root_arn, dest_arn=artifact_arn
54+
)
55+
except Exception as e:
56+
print(e)
57+
lineage_resource_helper.clean_all()
58+
assert False
59+
60+
try:
61+
lq = sagemaker.lineage.query.LineageQuery(sagemaker_session)
62+
lq_result = lq.query(start_arns=[wide_graph_root_arn])
63+
lq_result.visualize(path="wideGraph.html")
64+
except Exception as e:
65+
print(e)
66+
lineage_resource_helper.clean_all()
67+
assert False
68+
69+
lineage_resource_helper.clean_all()
70+
71+
72+
@pytest.mark.skip("visualizer load test")
73+
def test_long_graph_visualize(sagemaker_session):
74+
lineage_resource_helper = LineageResourceHelper(sagemaker_session=sagemaker_session)
75+
long_graph_root_arn = lineage_resource_helper.create_artifact(artifact_name=name())
76+
last_arn = long_graph_root_arn
77+
78+
# create long graph
79+
# Artifact -> Artifact -> ... -> Artifact
80+
try:
81+
for i in range(10):
82+
new_artifact_arn = lineage_resource_helper.create_artifact(artifact_name=name())
83+
lineage_resource_helper.create_association(
84+
source_arn=last_arn, dest_arn=new_artifact_arn
85+
)
86+
last_arn = new_artifact_arn
87+
except Exception as e:
88+
print(e)
89+
lineage_resource_helper.clean_all()
90+
assert False
91+
92+
try:
93+
lq = sagemaker.lineage.query.LineageQuery(sagemaker_session)
94+
lq_result = lq.query(
95+
start_arns=[long_graph_root_arn], direction=LineageQueryDirectionEnum.DESCENDANTS
96+
)
97+
# max depth = 10 -> graph rendered only has length of ten (in DESCENDANTS direction)
98+
lq_result.visualize(path="longGraph.html")
99+
except Exception as e:
100+
print(e)
101+
lineage_resource_helper.clean_all()
102+
assert False
103+
104+
lineage_resource_helper.clean_all()
105+
106+
107+
def test_graph_visualize(sagemaker_session):
108+
lineage_resource_helper = LineageResourceHelper(sagemaker_session=sagemaker_session)
109+
110+
# create lineage data
111+
# image artifact ------> model artifact(startarn) -> model deploy action -> endpoint context
112+
# /->
113+
# dataset artifact -/
114+
try:
115+
graph_startarn = lineage_resource_helper.create_artifact(
116+
artifact_name=name(), artifact_type="Model"
117+
)
118+
image_artifact = lineage_resource_helper.create_artifact(
119+
artifact_name=name(), artifact_type="Image"
120+
)
121+
lineage_resource_helper.create_association(
122+
source_arn=image_artifact, dest_arn=graph_startarn, association_type="ContributedTo"
123+
)
124+
dataset_artifact = lineage_resource_helper.create_artifact(
125+
artifact_name=name(), artifact_type="DataSet"
126+
)
127+
lineage_resource_helper.create_association(
128+
source_arn=dataset_artifact, dest_arn=graph_startarn, association_type="AssociatedWith"
129+
)
130+
modeldeploy_action = lineage_resource_helper.create_action(
131+
action_name=name(), action_type="ModelDeploy"
132+
)
133+
lineage_resource_helper.create_association(
134+
source_arn=graph_startarn, dest_arn=modeldeploy_action, association_type="ContributedTo"
135+
)
136+
endpoint_context = lineage_resource_helper.create_context(
137+
context_name=name(), context_type="Endpoint"
138+
)
139+
lineage_resource_helper.create_association(
140+
source_arn=modeldeploy_action,
141+
dest_arn=endpoint_context,
142+
association_type="AssociatedWith",
143+
)
144+
time.sleep(1)
145+
except Exception as e:
146+
print(e)
147+
lineage_resource_helper.clean_all()
148+
assert False
149+
150+
# visualize
151+
try:
152+
lq = sagemaker.lineage.query.LineageQuery(sagemaker_session)
153+
lq_result = lq.query(start_arns=[graph_startarn])
154+
lq_result.visualize(path="testGraph.html")
155+
except Exception as e:
156+
print(e)
157+
lineage_resource_helper.clean_all()
158+
assert False
159+
160+
# check generated graph info
161+
try:
162+
fo = open("testGraph.html", "r")
163+
lines = fo.readlines()
164+
for line in lines:
165+
if "nodes = " in line:
166+
node = line
167+
if "edges = " in line:
168+
edge = line
169+
170+
# extract node data
171+
start = node.find("[")
172+
end = node.find("]")
173+
res = node[start + 1 : end].split("}, ")
174+
res = [i + "}" for i in res]
175+
res[-1] = res[-1][:-1]
176+
node_dict = [json.loads(i) for i in res]
177+
178+
# extract edge data
179+
start = edge.find("[")
180+
end = edge.find("]")
181+
res = edge[start + 1 : end].split("}, ")
182+
res = [i + "}" for i in res]
183+
res[-1] = res[-1][:-1]
184+
edge_dict = [json.loads(i) for i in res]
185+
186+
# check node number
187+
assert len(node_dict) == 5
188+
189+
# check startarn
190+
found_value = next(
191+
dictionary for dictionary in node_dict if dictionary["id"] == graph_startarn
192+
)
193+
assert found_value["color"] == "#146eb4"
194+
assert found_value["label"] == "Model"
195+
assert found_value["shape"] == "star"
196+
assert found_value["title"] == "Artifact"
197+
198+
# check image artifact
199+
found_value = next(
200+
dictionary for dictionary in node_dict if dictionary["id"] == image_artifact
201+
)
202+
assert found_value["color"] == "#146eb4"
203+
assert found_value["label"] == "Image"
204+
assert found_value["shape"] == "dot"
205+
assert found_value["title"] == "Artifact"
206+
207+
# check dataset artifact
208+
found_value = next(
209+
dictionary for dictionary in node_dict if dictionary["id"] == dataset_artifact
210+
)
211+
assert found_value["color"] == "#146eb4"
212+
assert found_value["label"] == "DataSet"
213+
assert found_value["shape"] == "dot"
214+
assert found_value["title"] == "Artifact"
215+
216+
# check modeldeploy action
217+
found_value = next(
218+
dictionary for dictionary in node_dict if dictionary["id"] == modeldeploy_action
219+
)
220+
assert found_value["color"] == "#88c396"
221+
assert found_value["label"] == "ModelDeploy"
222+
assert found_value["shape"] == "dot"
223+
assert found_value["title"] == "Action"
224+
225+
# check endpoint context
226+
found_value = next(
227+
dictionary for dictionary in node_dict if dictionary["id"] == endpoint_context
228+
)
229+
assert found_value["color"] == "#ff9900"
230+
assert found_value["label"] == "Endpoint"
231+
assert found_value["shape"] == "dot"
232+
assert found_value["title"] == "Context"
233+
234+
# check edge number
235+
assert len(edge_dict) == 4
236+
237+
# check image_artifact -> model_artifact(startarn) edge
238+
found_value = next(
239+
dictionary for dictionary in edge_dict if dictionary["from"] == image_artifact
240+
)
241+
assert found_value["to"] == graph_startarn
242+
assert found_value["title"] == "ContributedTo"
243+
244+
# check dataset_artifact -> model_artifact(startarn) edge
245+
found_value = next(
246+
dictionary for dictionary in edge_dict if dictionary["from"] == dataset_artifact
247+
)
248+
assert found_value["to"] == graph_startarn
249+
assert found_value["title"] == "AssociatedWith"
250+
251+
# check model_artifact(startarn) -> modeldeploy_action edge
252+
found_value = next(
253+
dictionary for dictionary in edge_dict if dictionary["from"] == graph_startarn
254+
)
255+
assert found_value["to"] == modeldeploy_action
256+
assert found_value["title"] == "ContributedTo"
257+
258+
# check modeldeploy_action -> endpoint_context edge
259+
found_value = next(
260+
dictionary for dictionary in edge_dict if dictionary["from"] == modeldeploy_action
261+
)
262+
assert found_value["to"] == endpoint_context
263+
assert found_value["title"] == "AssociatedWith"
264+
265+
except Exception as e:
266+
print(e)
267+
lineage_resource_helper.clean_all()
268+
os.remove("testGraph.html")
269+
assert False
270+
271+
# delete generated test graph
272+
os.remove("testGraph.html")
273+
# clean lineage data
274+
lineage_resource_helper.clean_all()

0 commit comments

Comments
 (0)