Skip to content

feature: query lineage visualizer for general case #3

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 18, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
155 changes: 142 additions & 13 deletions src/sagemaker/lineage/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,12 +97,12 @@ def __str__(self):

Format:
{
'source_arn': 'string', 'destination_arn': 'string',
'source_arn': 'string', 'destination_arn': 'string',
'association_type': 'string'
}

"""
return (str(self.__dict__))
return str(self.__dict__)


class Vertex:
Expand Down Expand Up @@ -147,13 +147,13 @@ def __str__(self):

Format:
{
'arn': 'string', 'lineage_entity': 'string',
'lineage_source': 'string',
'arn': 'string', 'lineage_entity': 'string',
'lineage_source': 'string',
'_session': <sagemaker.session.Session object>
}

"""
return (str(self.__dict__))
return str(self.__dict__)

def to_lineage_object(self):
"""Convert the ``Vertex`` object to its corresponding lineage object.
Expand Down Expand Up @@ -201,6 +201,90 @@ def _artifact_to_lineage_object(self):
return Artifact.load(artifact_arn=self.arn, sagemaker_session=self._session)


class DashVisualizer(object):
    """Create object used for visualizing graph using Dash library.

    The visualization packages are imported when an instance is created rather
    than at module import time, so they are only required by callers that
    actually render a graph.
    """

    def __init__(self):
        """Init for DashVisualizer.

        Raises:
            ImportError: If any of the visualization packages cannot be imported.
        """
        # import visualization packages
        self.cyto, self.JupyterDash, self.html = self._import_visual_modules()

    def _import_visual_modules(self):
        """Import modules needed for visualization.

        Returns:
            tuple: ``(cyto, JupyterDash, html)`` where ``cyto`` is the
            ``dash_cytoscape`` module, ``JupyterDash`` comes from
            ``jupyter_dash`` and ``html`` comes from ``dash``.

        Raises:
            ImportError: If one of the packages is missing. The message carries
                the original error text plus the ``pip install`` hint, and the
                original exception is chained as the cause.
        """
        # Fail at the first missing package. Continuing after a failed import
        # would leave the local name unbound and surface later as an unrelated
        # UnboundLocalError at the return statement.
        try:
            import dash_cytoscape as cyto
        except ImportError as e:
            raise ImportError("{}; try pip install dash-cytoscape".format(e)) from e

        try:
            from jupyter_dash import JupyterDash
        except ImportError as e:
            raise ImportError("{}; try pip install jupyter-dash".format(e)) from e

        try:
            from dash import html
        except ImportError as e:
            raise ImportError("{}; try pip install dash".format(e)) from e

        return cyto, JupyterDash, html

    def _get_app(self, elements):
        """Create JupyterDash app for interactivity on Jupyter notebook.

        Args:
            elements (list): Node and edge dicts handed to the ``Cytoscape``
                component as its ``elements``.

        Returns:
            The ``JupyterDash`` app whose layout is a single ``Div`` wrapping
            one ``Cytoscape`` component.
        """
        app = self.JupyterDash(__name__)
        # Called before the layout below requests the "klay" layout by name.
        self.cyto.load_extra_layouts()

        app.layout = self.html.Div(
            [
                self.cyto.Cytoscape(
                    id="cytoscape-layout-1",
                    elements=elements,
                    style={"width": "100%", "height": "350px"},
                    layout={"name": "klay"},
                    stylesheet=[
                        {
                            "selector": "node",
                            "style": {
                                "label": "data(label)",
                                "font-size": "3.5vw",
                                "height": "10vw",
                                "width": "10vw",
                            },
                        },
                        {
                            "selector": "edge",
                            "style": {
                                "label": "data(label)",
                                "color": "gray",
                                "text-halign": "left",
                                "text-margin-y": "3px",
                                "text-margin-x": "-2px",
                                "font-size": "3%",
                                "width": "1%",
                                "curve-style": "taxi",
                                "target-arrow-color": "gray",
                                "target-arrow-shape": "triangle",
                                "line-color": "gray",
                                "arrow-scale": "0.5",
                            },
                        },
                    ],
                    responsive=True,
                )
            ]
        )

        return app

    def render(self, elements, mode):
        """Render graph for lineage query result.

        Args:
            elements (list): Node and edge dicts to display.
            mode: Passed through unchanged as ``mode`` to ``app.run_server``.

        Returns:
            Whatever ``app.run_server`` returns.
        """
        app = self._get_app(elements)

        return app.run_server(mode=mode)


class LineageQueryResult(object):
"""A wrapper around the results of a lineage query."""

Expand All @@ -226,29 +310,74 @@ def __init__(

def __str__(self):
"""Define string representation of ``LineageQueryResult``.

Format:
{
'edges':[
{
'source_arn': 'string', 'destination_arn': 'string',
'source_arn': 'string', 'destination_arn': 'string',
'association_type': 'string'
},
...
]
'vertices':[
{
'arn': 'string', 'lineage_entity': 'string',
'lineage_source': 'string',
'arn': 'string', 'lineage_entity': 'string',
'lineage_source': 'string',
'_session': <sagemaker.session.Session object>
},
...
]
}

"""
result_dict = vars(self)
return (str({k: [vars(val) for val in v] for k, v in result_dict.items()}))
return str({k: [vars(val) for val in v] for k, v in result_dict.items()})

def _covert_vertices_to_tuples(self):
"""Convert vertices to tuple format for visualizer."""
verts = []
for vert in self.vertices:
verts.append((vert.arn, vert.lineage_source))
return verts

def _covert_edges_to_tuples(self):
"""Convert edges to tuple format for visualizer."""
edges = []
for edge in self.edges:
edges.append((edge.source_arn, edge.destination_arn, edge.association_type))
return edges

def _get_visualization_elements(self):
"""Get elements for visualization."""
verts = self._covert_vertices_to_tuples()
edges = self._covert_edges_to_tuples()

nodes = [
{
"data": {"id": id, "label": label},
}
for id, label in verts
]

edges = [
{"data": {"source": source, "target": target, "label": label}}
for source, target, label in edges
]

elements = nodes + edges

return elements

def visualize(self):
    """Visualize lineage query result.

    Returns:
        The value returned by ``DashVisualizer.render`` for this result's
        elements, rendered with ``mode="inline"``.
    """
    # Elements are built before the visualizer is created, so a failure here
    # surfaces before any visualization package is imported.
    graph_elements = self._get_visualization_elements()
    return DashVisualizer().render(elements=graph_elements, mode="inline")


class LineageFilter(object):
Expand Down