Skip to content

change: add queryLineageResult visualizer load test & integ test #6

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 14 commits into from
Aug 3, 2022
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion requirements/extras/local_requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
urllib3==1.26.8
docker-compose==1.29.2
docker~=5.0.0
PyYAML==5.4.1
PyYAML==5.4.1
1 change: 1 addition & 0 deletions requirements/extras/test_requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,4 @@ fabric==2.6.0
requests==2.27.1
sagemaker-experiments==0.1.35
Jinja2==3.0.3
pyvis==0.2.1
4 changes: 2 additions & 2 deletions src/sagemaker/lineage/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,7 +364,7 @@ def _get_visualization_elements(self):
elements = {"nodes": verts, "edges": edges}
return elements

def visualize(self):
def visualize(self, path="pyvisExample.html"):
"""Visualize lineage query result."""
lineage_graph = {
# nodes can have shape / color
Expand Down Expand Up @@ -398,7 +398,7 @@ def visualize(self):

pyvis_vis = PyvisVisualizer(lineage_graph)
elements = self._get_visualization_elements()
return pyvis_vis.render(elements=elements)
return pyvis_vis.render(elements=elements, path=path)


class LineageFilter(object):
Expand Down
91 changes: 91 additions & 0 deletions tests/integ/sagemaker/lineage/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,3 +78,94 @@ def visit(arn, visited: set):

ret = []
return visit(start_arn, set())


class LineageResourceHelper:
    """Create SageMaker lineage resources for tests and delete them afterwards.

    Every resource created through this helper is recorded so that
    ``clean_all`` can later delete it on a best-effort basis.

    Args:
        sagemaker_session: object exposing a ``sagemaker_client`` attribute;
            presumably a ``sagemaker.session.Session`` whose client is the
            boto3 SageMaker client -- verify against the fixture in use.
    """

    def __init__(self, sagemaker_session):
        self.client = sagemaker_session.sagemaker_client
        # ARNs of created resources, in creation order.
        self.artifacts = []
        self.actions = []
        self.contexts = []
        # (source_arn, dest_arn) pairs of created associations.
        self.associations = []

    def create_artifact(self, artifact_name, artifact_type="Dataset"):
        """Create an artifact, track it, and return its ARN."""
        response = self.client.create_artifact(
            ArtifactName=artifact_name,
            Source={
                "SourceUri": "Test-artifact-" + artifact_name,
                "SourceTypes": [
                    {"SourceIdType": "S3ETag", "Value": "Test-artifact-sourceId-value"},
                ],
            },
            ArtifactType=artifact_type,
        )
        artifact_arn = response["ArtifactArn"]
        self.artifacts.append(artifact_arn)
        return artifact_arn

    def create_action(self, action_name, action_type="ModelDeployment"):
        """Create an action, track it, and return its ARN."""
        response = self.client.create_action(
            ActionName=action_name,
            Source={
                "SourceUri": "Test-action-" + action_name,
                "SourceType": "S3ETag",
                "SourceId": "Test-action-sourceId-value",
            },
            ActionType=action_type,
        )
        action_arn = response["ActionArn"]
        self.actions.append(action_arn)
        return action_arn

    def create_context(self, context_name, context_type="Endpoint"):
        """Create a context, track it, and return its ARN."""
        response = self.client.create_context(
            ContextName=context_name,
            Source={
                "SourceUri": "Test-context-" + context_name,
                "SourceType": "S3ETag",
                "SourceId": "Test-context-sourceId-value",
            },
            ContextType=context_type,
        )
        context_arn = response["ContextArn"]
        self.contexts.append(context_arn)
        return context_arn

    def create_association(self, source_arn, dest_arn, association_type="AssociatedWith"):
        """Associate ``source_arn`` with ``dest_arn``.

        Returns:
            bool: ``True`` (and the pair is tracked for cleanup) when the
            response contains a ``"SourceArn"`` key, ``False`` otherwise.
        """
        response = self.client.add_association(
            SourceArn=source_arn, DestinationArn=dest_arn, AssociationType=association_type
        )
        if "SourceArn" not in response:
            return False
        self.associations.append((source_arn, dest_arn))
        return True

    def _delete_tracked(self, tracked, delete):
        """Call ``delete`` on each tracked resource, keeping only the failures tracked.

        Failures are reported with ``print`` and never propagate, so cleanup
        stays best-effort. Successfully deleted resources are dropped from
        ``tracked`` so that a repeated cleanup only retries what failed.
        """
        remaining = []
        for resource in tracked:
            try:
                delete(resource)
                # Pause after each successful delete; presumably to stay under
                # API rate limits -- TODO confirm.
                time.sleep(0.5)
            except Exception as e:
                print("skipped " + str(e))
                remaining.append(resource)
        # Mutate in place so the public list attributes keep their identity.
        tracked[:] = remaining

    def clean_all(self):
        """Delete every tracked resource: associations first, then the entities."""
        self._delete_tracked(
            self.associations,
            lambda pair: self.client.delete_association(
                SourceArn=pair[0], DestinationArn=pair[1]
            ),
        )
        self._delete_tracked(
            self.artifacts, lambda arn: self.client.delete_artifact(ArtifactArn=arn)
        )
        self._delete_tracked(
            self.actions, lambda arn: self.client.delete_action(ActionArn=arn)
        )
        self._delete_tracked(
            self.contexts, lambda arn: self.client.delete_context(ContextArn=arn)
        )
274 changes: 274 additions & 0 deletions tests/integ/sagemaker/lineage/test_lineage_visualize.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,274 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""This module contains code to test SageMaker ``LineageQueryResult.visualize()``"""
from __future__ import absolute_import
import time
import json
import os

import pytest

import sagemaker.lineage.query
from sagemaker.lineage.query import LineageQueryDirectionEnum
from tests.integ.sagemaker.lineage.helpers import name, LineageResourceHelper


def test_LineageResourceHelper(sagemaker_session):
    """Smoke-test ``LineageResourceHelper``: create two artifacts and associate them.

    Any exception raised by the helper fails the test with its original
    traceback. Cleanup runs in ``finally`` so that resources created before a
    failure are still deleted.
    """
    lineage_resource_helper = LineageResourceHelper(sagemaker_session=sagemaker_session)
    try:
        art1 = lineage_resource_helper.create_artifact(artifact_name=name())
        art2 = lineage_resource_helper.create_artifact(artifact_name=name())
        lineage_resource_helper.create_association(source_arn=art1, dest_arn=art2)
    finally:
        lineage_resource_helper.clean_all()


@pytest.mark.skip("visualizer load test")
def test_wide_graph_visualize(sagemaker_session):
    """Load test: render a root artifact fanned out to ten child artifacts.

    The graph is written to ``wideGraph.html``; this test does not remove
    that file. Lineage resources are always cleaned up, and any failure
    surfaces with its original traceback.
    """
    lineage_resource_helper = LineageResourceHelper(sagemaker_session=sagemaker_session)

    try:
        wide_graph_root_arn = lineage_resource_helper.create_artifact(artifact_name=name())

        # create wide graph
        # Artifact ----> Artifact
        #        \ \ \-> Artifact
        #         \ \--> Artifact
        #          \---> ...
        for _ in range(10):
            artifact_arn = lineage_resource_helper.create_artifact(artifact_name=name())
            lineage_resource_helper.create_association(
                source_arn=wide_graph_root_arn, dest_arn=artifact_arn
            )

        lq = sagemaker.lineage.query.LineageQuery(sagemaker_session)
        lq_result = lq.query(start_arns=[wide_graph_root_arn])
        lq_result.visualize(path="wideGraph.html")
    finally:
        lineage_resource_helper.clean_all()


@pytest.mark.skip("visualizer load test")
def test_long_graph_visualize(sagemaker_session):
    """Load test: render a chain of eleven artifacts linked one after another.

    The graph is written to ``longGraph.html``; this test does not remove
    that file. Lineage resources are always cleaned up, and any failure
    surfaces with its original traceback.
    """
    lineage_resource_helper = LineageResourceHelper(sagemaker_session=sagemaker_session)

    try:
        long_graph_root_arn = lineage_resource_helper.create_artifact(artifact_name=name())
        last_arn = long_graph_root_arn

        # create long graph
        # Artifact -> Artifact -> ... -> Artifact
        for _ in range(10):
            new_artifact_arn = lineage_resource_helper.create_artifact(artifact_name=name())
            lineage_resource_helper.create_association(
                source_arn=last_arn, dest_arn=new_artifact_arn
            )
            last_arn = new_artifact_arn

        lq = sagemaker.lineage.query.LineageQuery(sagemaker_session)
        lq_result = lq.query(
            start_arns=[long_graph_root_arn], direction=LineageQueryDirectionEnum.DESCENDANTS
        )
        # max depth = 10 -> graph rendered only has length of ten (in DESCENDANTS direction)
        lq_result.visualize(path="longGraph.html")
    finally:
        lineage_resource_helper.clean_all()


def _read_vis_dataset(html_path, marker):
    """Return the JSON array found on the last line of ``html_path`` containing ``marker``.

    The array is decoded starting at the first ``[`` on that line; whatever
    follows the array on the same line is ignored.
    """
    with open(html_path, "r") as html_file:
        matching = [line for line in html_file if marker in line]
    assert matching, "no line containing {!r} in {}".format(marker, html_path)

    line = matching[-1]
    start = line.find("[")
    assert start != -1, "no JSON array on the {!r} line".format(marker)
    # raw_decode parses one JSON value and tolerates trailing text after it.
    records, _ = json.JSONDecoder().raw_decode(line, start)
    return records


def _first_match(records, key, value):
    """Return the first dict in ``records`` whose ``key`` equals ``value``."""
    matches = [record for record in records if record[key] == value]
    assert matches, "no record with {}={!r}".format(key, value)
    return matches[0]


def test_graph_visualize(sagemaker_session):
    """Render a small lineage graph and verify the nodes and edges in the HTML output.

    The generated ``testGraph.html`` and all created lineage resources are
    removed in ``finally``, whether the test passes or fails.
    """
    graph_path = "testGraph.html"
    lineage_resource_helper = LineageResourceHelper(sagemaker_session=sagemaker_session)

    try:
        # create lineage data
        # image artifact ------> model artifact(startarn) -> model deploy action -> endpoint context
        #                    /->
        # dataset artifact -/
        graph_startarn = lineage_resource_helper.create_artifact(
            artifact_name=name(), artifact_type="Model"
        )
        image_artifact = lineage_resource_helper.create_artifact(
            artifact_name=name(), artifact_type="Image"
        )
        lineage_resource_helper.create_association(
            source_arn=image_artifact, dest_arn=graph_startarn, association_type="ContributedTo"
        )
        dataset_artifact = lineage_resource_helper.create_artifact(
            artifact_name=name(), artifact_type="DataSet"
        )
        lineage_resource_helper.create_association(
            source_arn=dataset_artifact, dest_arn=graph_startarn, association_type="AssociatedWith"
        )
        modeldeploy_action = lineage_resource_helper.create_action(
            action_name=name(), action_type="ModelDeploy"
        )
        lineage_resource_helper.create_association(
            source_arn=graph_startarn, dest_arn=modeldeploy_action, association_type="ContributedTo"
        )
        endpoint_context = lineage_resource_helper.create_context(
            context_name=name(), context_type="Endpoint"
        )
        lineage_resource_helper.create_association(
            source_arn=modeldeploy_action,
            dest_arn=endpoint_context,
            association_type="AssociatedWith",
        )
        # Brief pause before querying; presumably lets the new lineage data
        # become visible to the query API -- TODO confirm.
        time.sleep(1)

        # visualize
        lq = sagemaker.lineage.query.LineageQuery(sagemaker_session)
        lq_result = lq.query(start_arns=[graph_startarn])
        lq_result.visualize(path=graph_path)

        # check generated graph info
        node_dict = _read_vis_dataset(graph_path, "nodes = ")
        edge_dict = _read_vis_dataset(graph_path, "edges = ")

        # check node number
        assert len(node_dict) == 5

        # (node id, color, label, shape, title); the start arn is the only "star"
        expected_nodes = [
            (graph_startarn, "#146eb4", "Model", "star", "Artifact"),
            (image_artifact, "#146eb4", "Image", "dot", "Artifact"),
            (dataset_artifact, "#146eb4", "DataSet", "dot", "Artifact"),
            (modeldeploy_action, "#88c396", "ModelDeploy", "dot", "Action"),
            (endpoint_context, "#ff9900", "Endpoint", "dot", "Context"),
        ]
        for node_id, color, label, shape, title in expected_nodes:
            found_value = _first_match(node_dict, "id", node_id)
            assert found_value["color"] == color
            assert found_value["label"] == label
            assert found_value["shape"] == shape
            assert found_value["title"] == title

        # check edge number
        assert len(edge_dict) == 4

        # (from, to, title)
        expected_edges = [
            (image_artifact, graph_startarn, "ContributedTo"),
            (dataset_artifact, graph_startarn, "AssociatedWith"),
            (graph_startarn, modeldeploy_action, "ContributedTo"),
            (modeldeploy_action, endpoint_context, "AssociatedWith"),
        ]
        for source, dest, title in expected_edges:
            found_value = _first_match(edge_dict, "from", source)
            assert found_value["to"] == dest
            assert found_value["title"] == title
    finally:
        # clean lineage data first: clean_all swallows its own errors, so the
        # file removal below is always reached
        lineage_resource_helper.clean_all()
        # delete generated test graph (absent if the test failed before rendering)
        if os.path.exists(graph_path):
            os.remove(graph_path)