Skip to content

Commit baeaace

Browse files
author
Yi-Ting Lee
committed
change: change visualization to using pyvis library
1 parent 794fb57 commit baeaace

File tree

1 file changed

+30
-238
lines changed

1 file changed

+30
-238
lines changed

src/sagemaker/lineage/query.py

Lines changed: 30 additions & 238 deletions
Original file line numberDiff line numberDiff line change
@@ -201,218 +201,47 @@ def _artifact_to_lineage_object(self):
201201
return Artifact.load(artifact_arn=self.arn, sagemaker_session=self._session)
202202

203203

204-
class DashVisualizer(object):
205-
"""Create object used for visualizing graph using Dash library."""
204+
class PyvisVisualizer(object):
205+
"""Create object used for visualizing graph using Pyvis library."""
206206

207207
def __init__(self, graph_styles):
208-
"""Init for DashVisualizer."""
208+
"""Init for PyvisVisualizer."""
209209
# import visualization packages
210210
(
211-
self.cyto,
212-
self.JupyterDash,
213-
self.html,
214-
self.Input,
215-
self.Output,
211+
self.pyvis,
212+
self.Network,
213+
self.Options,
216214
) = self._import_visual_modules()
217215

218216
self.graph_styles = graph_styles
219217

220218
def _import_visual_modules(self):
221219
"""Import modules needed for visualization."""
222220
try:
223-
import dash_cytoscape as cyto
221+
import pyvis
224222
except ImportError as e:
225223
print(e)
226-
print("Try: pip install dash-cytoscape")
224+
print("Try: pip install pyvis")
227225
raise
228226

229227
try:
230-
from jupyter_dash import JupyterDash
228+
from pyvis.network import Network
231229
except ImportError as e:
232230
print(e)
233-
print("Try: pip install jupyter-dash")
231+
print("Try: pip install pyvis")
234232
raise
235233

236234
try:
237-
from dash import html
235+
from pyvis.options import Options
238236
except ImportError as e:
239237
print(e)
240-
print("Try: pip install dash")
238+
print("Try: pip install pyvis")
241239
raise
242240

243-
try:
244-
from dash.dependencies import Input, Output
245-
except ImportError as e:
246-
print(e)
247-
print("Try: pip install dash")
248-
raise
249-
250-
return cyto, JupyterDash, html, Input, Output
251-
252-
def _create_legend_component(self, style):
253-
"""Create legend component div."""
254-
text = style["name"]
255-
symbol = ""
256-
color = "#ffffff"
257-
if style["isShape"] == "False":
258-
color = style["style"]["background-color"]
259-
else:
260-
symbol = style["symbol"]
261-
return self.html.Div(
262-
[
263-
self.html.Div(
264-
symbol,
265-
style={
266-
"background-color": color,
267-
"width": "1.5vw",
268-
"height": "1.5vw",
269-
"display": "inline-block",
270-
"font-size": "1.5vw",
271-
},
272-
),
273-
self.html.Div(
274-
style={
275-
"width": "0.5vw",
276-
"height": "1.5vw",
277-
"display": "inline-block",
278-
}
279-
),
280-
self.html.Div(
281-
text,
282-
style={"display": "inline-block", "font-size": "1.5vw"},
283-
),
284-
]
285-
)
286-
287-
def _create_entity_selector(self, entity_name, style):
288-
"""Create selector for each lineage entity."""
289-
return {"selector": "." + entity_name, "style": style["style"]}
290-
291-
def _get_app(self, elements):
292-
"""Create JupyterDash app for interactivity on Jupyter notebook."""
293-
app = self.JupyterDash(__name__)
294-
self.cyto.load_extra_layouts()
295-
296-
app.layout = self.html.Div(
297-
[
298-
# graph section
299-
self.cyto.Cytoscape(
300-
id="cytoscape-graph",
301-
elements=elements,
302-
style={
303-
"width": "84%",
304-
"height": "350px",
305-
"display": "inline-block",
306-
"border-width": "1vw",
307-
"border-color": "#232f3e",
308-
},
309-
layout={"name": "klay"},
310-
stylesheet=[
311-
{
312-
"selector": "node",
313-
"style": {
314-
"label": "data(label)",
315-
"font-size": "3.5vw",
316-
"height": "10vw",
317-
"width": "10vw",
318-
"border-width": "0.8",
319-
"border-opacity": "0",
320-
"border-color": "#232f3e",
321-
"font-family": "verdana",
322-
},
323-
},
324-
{
325-
"selector": "edge",
326-
"style": {
327-
"label": "data(label)",
328-
"color": "gray",
329-
"text-halign": "left",
330-
"text-margin-y": "2.5",
331-
"font-size": "3",
332-
"width": "1",
333-
"curve-style": "bezier",
334-
"control-point-step-size": "15",
335-
"target-arrow-color": "gray",
336-
"target-arrow-shape": "triangle",
337-
"line-color": "gray",
338-
"arrow-scale": "0.5",
339-
"font-family": "verdana",
340-
},
341-
},
342-
{"selector": ".select", "style": {"border-opacity": "0.7"}},
343-
]
344-
+ [self._create_entity_selector(k, v) for k, v in self.graph_styles.items()],
345-
responsive=True,
346-
),
347-
self.html.Div(
348-
style={
349-
"width": "0.5%",
350-
"display": "inline-block",
351-
"font-size": "1vw",
352-
"font-family": "verdana",
353-
"vertical-align": "top",
354-
},
355-
),
356-
# legend section
357-
self.html.Div(
358-
[self._create_legend_component(v) for k, v in self.graph_styles.items()],
359-
style={
360-
"display": "inline-block",
361-
"font-size": "1vw",
362-
"font-family": "verdana",
363-
"vertical-align": "top",
364-
},
365-
),
366-
]
367-
)
368-
369-
@app.callback(
370-
self.Output("cytoscape-graph", "elements"),
371-
self.Input("cytoscape-graph", "tapNodeData"),
372-
self.Input("cytoscape-graph", "elements"),
373-
)
374-
def selectNode(tapData, elements):
375-
for n in elements:
376-
if tapData is not None and n["data"]["id"] == tapData["id"]:
377-
# if is tapped node, add "select" class to node
378-
n["classes"] += " select"
379-
elif "classes" in n:
380-
# remove "select" class in "classes" if node not selected
381-
n["classes"] = n["classes"].replace("select", "")
382-
383-
return elements
384-
385-
return app
386-
387-
def render(self, elements, mode):
388-
"""Render graph for lineage query result."""
389-
app = self._get_app(elements)
390-
391-
return app.run_server(mode=mode)
392-
393-
class PyvisVisualizer(object):
394-
"""Create object used for visualizing graph using Pyvis library."""
395-
396-
def __init__(self, graph_styles):
397-
"""Init for PyvisVisualizer."""
398-
# import visualization packages
399-
(
400-
self.pyvis,
401-
self.Network,
402-
self.Options,
403-
) = self._import_visual_modules()
404-
405-
self.graph_styles = graph_styles
406-
407-
def _import_visual_modules(self):
408-
import pyvis
409-
from pyvis.network import Network
410-
from pyvis.options import Options
411-
# No module named 'pyvis'
412-
413241
return pyvis, Network, Options
414242

415243
def _get_options(self):
244+
"""Get pyvis graph options."""
416245
options = """
417246
var options = {
418247
"configure":{
@@ -445,24 +274,29 @@ def _get_options(self):
445274
return options
446275

447276
def _node_color(self, n):
277+
"""Return node color by background-color specified in graph styles."""
448278
return self.graph_styles[n[2]]["style"]["background-color"]
449279

450-
def render(self, elements):
451-
net = self.Network(height='500px', width='100%', notebook = True, directed = True)
280+
def render(self, elements, path="pyvisExample.html"):
281+
"""Render graph for lineage query result."""
282+
net = self.Network(height="500px", width="100%", notebook=True, directed=True)
452283
options = self._get_options()
453-
net.set_options(options)
284+
net.set_options(options)
454285

286+
# add nodes to graph
455287
for n in elements["nodes"]:
456-
if(n[3]==True): # startarn
457-
net.add_node(n[0], label=n[1], title=n[1], color=self._node_color(n), shape="star")
288+
if n[3]: # startarn
289+
net.add_node(n[0], label=n[1], title=n[2], color=self._node_color(n), shape="star")
458290
else:
459-
net.add_node(n[0], label=n[1], title=n[1], color=self._node_color(n))
291+
net.add_node(n[0], label=n[1], title=n[2], color=self._node_color(n))
460292

293+
# add edges to graph
461294
for e in elements["edges"]:
462295
print(e)
463-
net.add_edge(e[0], e[1], title=e[2])
296+
net.add_edge(e[0], e[1], title=e[2])
297+
298+
return net.show(path)
464299

465-
return net.show('pyvisExample.html')
466300

467301
class LineageQueryResult(object):
468302
"""A wrapper around the results of a lineage query."""
@@ -522,18 +356,6 @@ def __str__(self):
522356
result_dict = vars(self)
523357
return str({k: [str(val) for val in v] for k, v in result_dict.items()})
524358

525-
def _covert_vertices_to_tuples(self):
526-
"""Convert vertices to tuple format for visualizer."""
527-
verts = []
528-
# get vertex info in the form of (id, label, class)
529-
for vert in self.vertices:
530-
if vert.arn in self.startarn:
531-
# add "startarn" class to node if arn is a startarn
532-
verts.append((vert.arn, vert.lineage_source, vert.lineage_entity + " startarn"))
533-
else:
534-
verts.append((vert.arn, vert.lineage_source, vert.lineage_entity))
535-
return verts
536-
537359
def _covert_edges_to_tuples(self):
538360
"""Convert edges to tuple format for visualizer."""
539361
edges = []
@@ -542,7 +364,7 @@ def _covert_edges_to_tuples(self):
542364
edges.append((edge.source_arn, edge.destination_arn, edge.association_type))
543365
return edges
544366

545-
def _pyvis_covert_vertices_to_tuples(self):
367+
def _covert_vertices_to_tuples(self):
546368
"""Convert vertices to tuple format for visualizer."""
547369
verts = []
548370
# get vertex info in the form of (id, label, class)
@@ -555,38 +377,15 @@ def _pyvis_covert_vertices_to_tuples(self):
555377
return verts
556378

557379
def _get_visualization_elements(self):
558-
"""Get elements for visualization."""
559-
# get vertices and edges info for graph
380+
"""Get elements(nodes+edges) for visualization."""
560381
verts = self._covert_vertices_to_tuples()
561382
edges = self._covert_edges_to_tuples()
562383

563-
nodes = [
564-
{"data": {"id": id, "label": label}, "classes": classes} for id, label, classes in verts
565-
]
566-
567-
edges = [
568-
{"data": {"source": source, "target": target, "label": label}}
569-
for source, target, label in edges
570-
]
571-
572-
elements = nodes + edges
573-
574-
return elements
575-
576-
def _get_pyvis_visualization_elements(self):
577-
verts = self._pyvis_covert_vertices_to_tuples()
578-
edges = self._covert_edges_to_tuples()
579-
580-
elements = {
581-
"nodes": verts,
582-
"edges": edges
583-
}
384+
elements = {"nodes": verts, "edges": edges}
584385
return elements
585386

586387
def visualize(self):
587388
"""Visualize lineage query result."""
588-
elements = self._get_visualization_elements()
589-
590389
lineage_graph = {
591390
# nodes can have shape / color
592391
"TrialComponent": {
@@ -617,15 +416,8 @@ def visualize(self):
617416
},
618417
}
619418

620-
# initialize DashVisualizer instance to render graph & interactive components
621-
# dash_vis = DashVisualizer(lineage_graph)
622-
623-
# dash_server = dash_vis.render(elements=elements, mode="inline")
624-
625-
# return dash_server
626-
627419
pyvis_vis = PyvisVisualizer(lineage_graph)
628-
elements = self._get_pyvis_visualization_elements()
420+
elements = self._get_visualization_elements()
629421
return pyvis_vis.render(elements=elements)
630422

631423

0 commit comments

Comments
 (0)