Skip to content

change: add queryLineageResult visualizer load test & integ test #6

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 14 commits into from
Aug 3, 2022
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion requirements/extras/local_requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
urllib3==1.26.8
docker-compose==1.29.2
docker~=5.0.0
PyYAML==5.4.1
PyYAML==5.4.1
1 change: 1 addition & 0 deletions requirements/extras/test_requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,4 @@ fabric==2.6.0
requests==2.27.1
sagemaker-experiments==0.1.35
Jinja2==3.0.3
pyvis==0.2.1
4 changes: 2 additions & 2 deletions src/sagemaker/lineage/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,7 +364,7 @@ def _get_visualization_elements(self):
elements = {"nodes": verts, "edges": edges}
return elements

def visualize(self):
def visualize(self, path="pyvisExample.html"):
"""Visualize lineage query result."""
lineage_graph = {
# nodes can have shape / color
Expand Down Expand Up @@ -398,7 +398,7 @@ def visualize(self):

pyvis_vis = PyvisVisualizer(lineage_graph)
elements = self._get_visualization_elements()
return pyvis_vis.render(elements=elements)
return pyvis_vis.render(elements=elements, path=path)


class LineageFilter(object):
Expand Down
15 changes: 15 additions & 0 deletions tests/integ/sagemaker/lineage/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import pytest
import logging
import uuid
import json
from sagemaker.lineage import (
action,
context,
Expand Down Expand Up @@ -891,3 +892,17 @@ def _deploy_static_endpoint(execution_arn, sagemaker_session):
pass
else:
raise (e)


@pytest.fixture
def extract_data_from_html():
    """Provide a helper that pulls the first JSON array out of a line of text.

    The returned callable takes a single string (the tests pass it a line of the
    rendered HTML that contains ``nodes = `` or ``edges = ``) and returns the
    decoded content of the first ``[...]`` array found in it.
    """

    def _method(data):
        """Decode the first JSON array embedded in ``data``.

        Args:
            data (str): Text containing a JSON array somewhere after arbitrary
                leading text.

        Returns:
            list: The decoded array; an empty list for ``[]``.

        Raises:
            ValueError: If ``data`` contains no ``[`` or the text starting there
                is not valid JSON (``json.JSONDecodeError`` is a ``ValueError``).
        """
        start = data.find("[")
        if start == -1:
            raise ValueError("no JSON array found in: " + repr(data))
        # raw_decode parses exactly one JSON value starting at ``start`` and ignores
        # whatever follows it, so brackets or "}, " sequences inside string values
        # and trailing script text after the array cannot confuse the extraction.
        data_dict, _ = json.JSONDecoder().raw_decode(data, start)
        return data_dict

    return _method
87 changes: 87 additions & 0 deletions tests/integ/sagemaker/lineage/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,3 +78,90 @@ def visit(arn, visited: set):

ret = []
return visit(start_arn, set())


class LineageResourceHelper:
    """Create lineage entities for a test and delete them again afterwards.

    Every ARN returned by one of the ``create_*`` methods is recorded on the
    instance so that ``clean_all`` can delete it later.
    """

    def __init__(self, sagemaker_session):
        """Bind the helper to the client held by ``sagemaker_session``.

        Args:
            sagemaker_session: Object exposing a ``sagemaker_client`` attribute
                that provides the create/add/delete calls used by this class.
        """
        self.client = sagemaker_session.sagemaker_client
        # ARNs of the entities created through this helper, in creation order.
        self.artifacts = []
        self.actions = []
        self.contexts = []
        # (source_arn, dest_arn) pairs of the associations created through this helper.
        self.associations = []

    def create_artifact(self, artifact_name, artifact_type="Dataset"):
        """Create an artifact, remember its ARN and return it.

        Args:
            artifact_name (str): Name of the artifact; also used in its source URI.
            artifact_type (str): Artifact type passed to the service.

        Returns:
            str: The ``ArtifactArn`` from the service response.
        """
        response = self.client.create_artifact(
            ArtifactName=artifact_name,
            Source={
                "SourceUri": "Test-artifact-" + artifact_name,
                "SourceTypes": [
                    {"SourceIdType": "S3ETag", "Value": "Test-artifact-sourceId-value"},
                ],
            },
            ArtifactType=artifact_type,
        )
        artifact_arn = response["ArtifactArn"]
        self.artifacts.append(artifact_arn)

        return artifact_arn

    def create_action(self, action_name, action_type="ModelDeployment"):
        """Create an action, remember its ARN and return it.

        Args:
            action_name (str): Name of the action; also used in its source URI.
            action_type (str): Action type passed to the service.

        Returns:
            str: The ``ActionArn`` from the service response.
        """
        response = self.client.create_action(
            ActionName=action_name,
            Source={
                "SourceUri": "Test-action-" + action_name,
                "SourceType": "S3ETag",
                "SourceId": "Test-action-sourceId-value",
            },
            ActionType=action_type,
        )
        action_arn = response["ActionArn"]
        self.actions.append(action_arn)

        return action_arn

    def create_context(self, context_name, context_type="Endpoint"):
        """Create a context, remember its ARN and return it.

        Args:
            context_name (str): Name of the context; also used in its source URI.
            context_type (str): Context type passed to the service.

        Returns:
            str: The ``ContextArn`` from the service response.
        """
        response = self.client.create_context(
            ContextName=context_name,
            Source={
                "SourceUri": "Test-context-" + context_name,
                "SourceType": "S3ETag",
                "SourceId": "Test-context-sourceId-value",
            },
            ContextType=context_type,
        )
        context_arn = response["ContextArn"]
        self.contexts.append(context_arn)

        return context_arn

    def create_association(self, source_arn, dest_arn, association_type="AssociatedWith"):
        """Associate two entities and remember the pair for cleanup.

        Args:
            source_arn (str): ARN of the association source.
            dest_arn (str): ARN of the association destination.
            association_type (str): Association type passed to the service.

        Returns:
            bool: ``True`` when the response contains ``SourceArn`` (the pair is
            then tracked for cleanup), ``False`` otherwise.
        """
        response = self.client.add_association(
            SourceArn=source_arn, DestinationArn=dest_arn, AssociationType=association_type
        )
        if "SourceArn" not in response:
            return False

        self.associations.append((source_arn, dest_arn))
        return True

    def clean_all(self):
        """Delete every tracked resource on a best-effort basis.

        Deletion order is associations, artifacts, actions, contexts. A failed
        delete is reported on stdout and does not stop the cleanup. Resources that
        were deleted are dropped from the tracking lists, so calling this method
        again only retries the ones that failed.
        """
        self._delete_tracked(
            self.associations,
            lambda pair: self.client.delete_association(
                SourceArn=pair[0], DestinationArn=pair[1]
            ),
        )
        self._delete_tracked(
            self.artifacts, lambda arn: self.client.delete_artifact(ArtifactArn=arn)
        )
        self._delete_tracked(self.actions, lambda arn: self.client.delete_action(ActionArn=arn))
        self._delete_tracked(
            self.contexts, lambda arn: self.client.delete_context(ContextArn=arn)
        )

    @staticmethod
    def _delete_tracked(resources, delete):
        """Call ``delete`` on each tracked resource, keeping only those that failed.

        Args:
            resources (list): Tracking list; updated in place to hold the failures.
            delete (callable): Deletes a single resource; may raise.
        """
        failed = []
        for resource in resources:
            try:
                delete(resource)
            except Exception as e:  # cleanup must continue past any single failure
                print("skipped " + str(e))
                failed.append(resource)
        # Slice assignment keeps the same list object for anyone holding a reference.
        resources[:] = failed
232 changes: 232 additions & 0 deletions tests/integ/sagemaker/lineage/test_lineage_visualize.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,232 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""This module contains code to test SageMaker ``LineageQueryResult.visualize()``"""
from __future__ import absolute_import
import time
import os

import pytest

import sagemaker.lineage.query
from sagemaker.lineage.query import LineageQueryDirectionEnum
from tests.integ.sagemaker.lineage.helpers import name, LineageResourceHelper


def test_LineageResourceHelper(sagemaker_session):
    """Check that ``LineageResourceHelper`` can create, associate and clean up artifacts.

    Any exception raised by the helper propagates, so pytest reports the failure
    together with its traceback.
    """
    lineage_resource_helper = LineageResourceHelper(sagemaker_session=sagemaker_session)
    try:
        art1 = lineage_resource_helper.create_artifact(artifact_name=name())
        art2 = lineage_resource_helper.create_artifact(artifact_name=name())
        lineage_resource_helper.create_association(source_arn=art1, dest_arn=art2)
    finally:
        # Runs even when a create call fails, so already-created resources are not leaked.
        lineage_resource_helper.clean_all()


@pytest.mark.skip("visualizer load test")
def test_wide_graph_visualize(sagemaker_session):
    """Load test: render a graph whose root artifact has 200 direct neighbours.

    The rendered page is written to ``wideGraph.html``. Errors propagate so that
    pytest shows the original traceback.
    """
    lineage_resource_helper = LineageResourceHelper(sagemaker_session=sagemaker_session)
    wide_graph_root_arn = lineage_resource_helper.create_artifact(artifact_name=name())

    # create wide graph
    # Artifact ----> Artifact
    #        \ \ \-> Artifact
    #         \ \--> Artifact
    #          \---> ...
    try:
        for _ in range(200):
            artifact_arn = lineage_resource_helper.create_artifact(artifact_name=name())
            lineage_resource_helper.create_association(
                source_arn=wide_graph_root_arn, dest_arn=artifact_arn
            )

        lq = sagemaker.lineage.query.LineageQuery(sagemaker_session)
        lq_result = lq.query(start_arns=[wide_graph_root_arn])
        lq_result.visualize(path="wideGraph.html")

    finally:
        # Always remove the lineage entities created above, pass or fail.
        lineage_resource_helper.clean_all()


@pytest.mark.skip("visualizer load test")
def test_long_graph_visualize(sagemaker_session):
    """Load test: render a chain of artifacts queried in the DESCENDANTS direction.

    The rendered page is written to ``longGraph.html``. Errors propagate so that
    pytest shows the original traceback.
    """
    lineage_resource_helper = LineageResourceHelper(sagemaker_session=sagemaker_session)
    long_graph_root_arn = lineage_resource_helper.create_artifact(artifact_name=name())
    last_arn = long_graph_root_arn

    # create long graph
    # Artifact -> Artifact -> ... -> Artifact
    try:
        for _ in range(10):
            new_artifact_arn = lineage_resource_helper.create_artifact(artifact_name=name())
            lineage_resource_helper.create_association(
                source_arn=last_arn, dest_arn=new_artifact_arn
            )
            # The newest artifact becomes the source of the next link in the chain.
            last_arn = new_artifact_arn

        lq = sagemaker.lineage.query.LineageQuery(sagemaker_session)
        lq_result = lq.query(
            start_arns=[long_graph_root_arn], direction=LineageQueryDirectionEnum.DESCENDANTS
        )
        # max depth = 10 -> the rendered graph only has a length of ten
        # (in the DESCENDANTS direction)
        lq_result.visualize(path="longGraph.html")

    finally:
        # Always remove the lineage entities created above, pass or fail.
        lineage_resource_helper.clean_all()


def test_graph_visualize(sagemaker_session, extract_data_from_html):
    """Render a small lineage graph and verify the nodes and edges in the HTML output.

    Args:
        sagemaker_session: Session used to create the lineage entities and run the query.
        extract_data_from_html: Fixture returning a callable that decodes the JSON
            array embedded in a line of the rendered page.
    """
    lineage_resource_helper = LineageResourceHelper(sagemaker_session=sagemaker_session)
    html_path = "testGraph.html"

    # create lineage data
    # image artifact ------> model artifact(startarn) -> model deploy action -> endpoint context
    #                    /->
    # dataset artifact -/
    try:
        graph_startarn = lineage_resource_helper.create_artifact(
            artifact_name=name(), artifact_type="Model"
        )
        image_artifact = lineage_resource_helper.create_artifact(
            artifact_name=name(), artifact_type="Image"
        )
        lineage_resource_helper.create_association(
            source_arn=image_artifact, dest_arn=graph_startarn, association_type="ContributedTo"
        )
        dataset_artifact = lineage_resource_helper.create_artifact(
            artifact_name=name(), artifact_type="DataSet"
        )
        lineage_resource_helper.create_association(
            source_arn=dataset_artifact, dest_arn=graph_startarn, association_type="AssociatedWith"
        )
        modeldeploy_action = lineage_resource_helper.create_action(
            action_name=name(), action_type="ModelDeploy"
        )
        lineage_resource_helper.create_association(
            source_arn=graph_startarn, dest_arn=modeldeploy_action, association_type="ContributedTo"
        )
        endpoint_context = lineage_resource_helper.create_context(
            context_name=name(), context_type="Endpoint"
        )
        lineage_resource_helper.create_association(
            source_arn=modeldeploy_action,
            dest_arn=endpoint_context,
            association_type="AssociatedWith",
        )
        # NOTE(review): presumably gives the service time to make the new entities
        # visible to the lineage query -- confirm the required delay.
        time.sleep(3)

        # visualize
        lq = sagemaker.lineage.query.LineageQuery(sagemaker_session)
        lq_result = lq.query(start_arns=[graph_startarn])
        lq_result.visualize(path=html_path)

        # check generated graph info; the last matching line wins
        node_line = None
        edge_line = None
        with open(html_path, "r") as fo:
            for line in fo:
                if "nodes = " in line:
                    node_line = line
                if "edges = " in line:
                    edge_line = line

        assert node_line is not None, "no 'nodes = ' line in " + html_path
        assert edge_line is not None, "no 'edges = ' line in " + html_path

        node_dict = extract_data_from_html(node_line)
        edge_dict = extract_data_from_html(edge_line)

        # check node number
        assert len(node_dict) == 5

        expected_nodes = {
            graph_startarn: {
                "color": "#146eb4",
                "label": "Model",
                "shape": "star",
                "title": "Artifact",
            },
            image_artifact: {
                "color": "#146eb4",
                "label": "Image",
                "shape": "dot",
                "title": "Artifact",
            },
            dataset_artifact: {
                "color": "#146eb4",
                "label": "DataSet",
                "shape": "dot",
                "title": "Artifact",
            },
            modeldeploy_action: {
                "color": "#88c396",
                "label": "ModelDeploy",
                "shape": "dot",
                "title": "Action",
            },
            endpoint_context: {
                "color": "#ff9900",
                "label": "Endpoint",
                "shape": "dot",
                "title": "Context",
            },
        }

        # check node properties
        for node in node_dict:
            for label, val in expected_nodes[node["id"]].items():
                assert node[label] == val

        # check edge number
        assert len(edge_dict) == 4

        # keyed by the edge's source ARN; each source has exactly one outgoing edge here
        expected_edges = {
            image_artifact: {
                "from": image_artifact,
                "to": graph_startarn,
                "title": "ContributedTo",
            },
            dataset_artifact: {
                "from": dataset_artifact,
                "to": graph_startarn,
                "title": "AssociatedWith",
            },
            graph_startarn: {
                "from": graph_startarn,
                "to": modeldeploy_action,
                "title": "ContributedTo",
            },
            modeldeploy_action: {
                "from": modeldeploy_action,
                "to": endpoint_context,
                "title": "AssociatedWith",
            },
        }

        # check edge properties
        for edge in edge_dict:
            for label, val in expected_edges[edge["from"]].items():
                assert edge[label] == val

    finally:
        lineage_resource_helper.clean_all()
        # The file only exists if rendering got that far; an unconditional remove
        # would raise FileNotFoundError and hide the real failure.
        if os.path.exists(html_path):
            os.remove(html_path)