Skip to content

Commit b6b5ee4

Browse files
author
Yi-Ting Lee
committed
PyvisVisualizer added
1 parent 0b35553 commit b6b5ee4

File tree

1 file changed

+101
-3
lines changed

1 file changed

+101
-3
lines changed

src/sagemaker/lineage/query.py

Lines changed: 101 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -390,6 +390,78 @@ def render(self, elements, mode):
390390

391391
return app.run_server(mode=mode)
392392

393+
class PyvisVisualizer(object):
394+
"""Create object used for visualizing graph using Pyvis library."""
395+
396+
def __init__(self, graph_styles):
397+
"""Init for PyvisVisualizer."""
398+
# import visualization packages
399+
(
400+
self.pyvis,
401+
self.Network,
402+
self.Options,
403+
) = self._import_visual_modules()
404+
405+
self.graph_styles = graph_styles
406+
407+
def _import_visual_modules(self):
408+
import pyvis
409+
from pyvis.network import Network
410+
from pyvis.options import Options
411+
412+
return pyvis, Network, Options
413+
414+
def _get_options(self):
415+
options = """
416+
var options = {
417+
"configure":{
418+
"enabled": true
419+
},
420+
"layout": {
421+
"hierarchical": {
422+
"enabled": true,
423+
"blockShifting": false,
424+
"direction": "LR",
425+
"sortMethod": "directed",
426+
"shakeTowards": "roots"
427+
}
428+
},
429+
"interaction": {
430+
"multiselect": true,
431+
"navigationButtons": true
432+
},
433+
"physics": {
434+
"enabled": false,
435+
"hierarchicalRepulsion": {
436+
"centralGravity": 0,
437+
"avoidOverlap": null
438+
},
439+
"minVelocity": 0.75,
440+
"solver": "hierarchicalRepulsion"
441+
}
442+
}
443+
"""
444+
return options
445+
446+
def _node_color(self, n):
447+
return self.graph_styles[n[2]]["style"]["background-color"]
448+
449+
def render(self, elements):
450+
net = self.Network(height='500px', width='100%', notebook = True, directed = True)
451+
options = self._get_options()
452+
net.set_options(options)
453+
454+
for n in elements["nodes"]:
455+
if(n[3]==True): # startarn
456+
net.add_node(n[0], label=n[1], title=n[1], color=self._node_color(n), shape="star")
457+
else:
458+
net.add_node(n[0], label=n[1], title=n[1], color=self._node_color(n))
459+
460+
for e in elements["edges"]:
461+
print(e)
462+
net.add_edge(e[0], e[1], title=e[2])
463+
464+
return net.show('pyvisExample.html')
393465

394466
class LineageQueryResult(object):
395467
"""A wrapper around the results of a lineage query."""
@@ -469,6 +541,18 @@ def _covert_edges_to_tuples(self):
469541
edges.append((edge.source_arn, edge.destination_arn, edge.association_type))
470542
return edges
471543

544+
def _pyvis_covert_vertices_to_tuples(self):
545+
"""Convert vertices to tuple format for visualizer."""
546+
verts = []
547+
# get vertex info in the form of (id, label, class)
548+
for vert in self.vertices:
549+
if vert.arn in self.startarn:
550+
# add "startarn" class to node if arn is a startarn
551+
verts.append((vert.arn, vert.lineage_source, vert.lineage_entity, True))
552+
else:
553+
verts.append((vert.arn, vert.lineage_source, vert.lineage_entity, False))
554+
return verts
555+
472556
def _get_visualization_elements(self):
473557
"""Get elements for visualization."""
474558
# get vertices and edges info for graph
@@ -488,6 +572,16 @@ def _get_visualization_elements(self):
488572

489573
return elements
490574

575+
def _get_pyvis_visualization_elements(self):
576+
verts = self._pyvis_covert_vertices_to_tuples()
577+
edges = self._covert_edges_to_tuples()
578+
579+
elements = {
580+
"nodes": verts,
581+
"edges": edges
582+
}
583+
return elements
584+
491585
def visualize(self):
492586
"""Visualize lineage query result."""
493587
elements = self._get_visualization_elements()
@@ -523,11 +617,15 @@ def visualize(self):
523617
}
524618

525619
# initialize DashVisualizer instance to render graph & interactive components
526-
dash_vis = DashVisualizer(lineage_graph)
620+
# dash_vis = DashVisualizer(lineage_graph)
621+
622+
# dash_server = dash_vis.render(elements=elements, mode="inline")
527623

528-
dash_server = dash_vis.render(elements=elements, mode="inline")
624+
# return dash_server
529625

530-
return dash_server
626+
pyvis_vis = PyvisVisualizer(lineage_graph)
627+
elements = self._get_pyvis_visualization_elements()
628+
return pyvis_vis.render(elements=elements)
531629

532630

533631
class LineageFilter(object):

0 commit comments

Comments
 (0)