Skip to content

Commit cbe445e

Browse files
authored
Merge pull request #3 from ytlee93/master
feature: query lineage visualizer for general case
2 parents e0c59c2 + 79b193a commit cbe445e

File tree

1 file changed

+142
-13
lines changed

1 file changed

+142
-13
lines changed

src/sagemaker/lineage/query.py

Lines changed: 142 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -97,12 +97,12 @@ def __str__(self):
9797
9898
Format:
9999
{
100-
'source_arn': 'string', 'destination_arn': 'string',
100+
'source_arn': 'string', 'destination_arn': 'string',
101101
'association_type': 'string'
102102
}
103-
103+
104104
"""
105-
return (str(self.__dict__))
105+
return str(self.__dict__)
106106

107107

108108
class Vertex:
@@ -147,13 +147,13 @@ def __str__(self):
147147
148148
Format:
149149
{
150-
'arn': 'string', 'lineage_entity': 'string',
151-
'lineage_source': 'string',
150+
'arn': 'string', 'lineage_entity': 'string',
151+
'lineage_source': 'string',
152152
'_session': <sagemaker.session.Session object>
153153
}
154-
154+
155155
"""
156-
return (str(self.__dict__))
156+
return str(self.__dict__)
157157

158158
def to_lineage_object(self):
159159
"""Convert the ``Vertex`` object to its corresponding lineage object.
@@ -201,6 +201,90 @@ def _artifact_to_lineage_object(self):
201201
return Artifact.load(artifact_arn=self.arn, sagemaker_session=self._session)
202202

203203

204+
class DashVisualizer(object):
205+
"""Create object used for visualizing graph using Dash library."""
206+
207+
def __init__(self):
208+
"""Init for DashVisualizer."""
209+
# import visualization packages
210+
self.cyto, self.JupyterDash, self.html = self._import_visual_modules()
211+
212+
def _import_visual_modules(self):
213+
"""Import modules needed for visualization."""
214+
try:
215+
import dash_cytoscape as cyto
216+
except ImportError as e:
217+
print(e)
218+
print("try pip install dash-cytoscape")
219+
220+
try:
221+
from jupyter_dash import JupyterDash
222+
except ImportError as e:
223+
print(e)
224+
print("try pip install jupyter-dash")
225+
226+
try:
227+
from dash import html
228+
except ImportError as e:
229+
print(e)
230+
print("try pip install dash")
231+
232+
return cyto, JupyterDash, html
233+
234+
def _get_app(self, elements):
235+
"""Create JupyterDash app for interactivity on Jupyter notebook."""
236+
app = self.JupyterDash(__name__)
237+
self.cyto.load_extra_layouts()
238+
239+
app.layout = self.html.Div(
240+
[
241+
self.cyto.Cytoscape(
242+
id="cytoscape-layout-1",
243+
elements=elements,
244+
style={"width": "100%", "height": "350px"},
245+
layout={"name": "klay"},
246+
stylesheet=[
247+
{
248+
"selector": "node",
249+
"style": {
250+
"label": "data(label)",
251+
"font-size": "3.5vw",
252+
"height": "10vw",
253+
"width": "10vw",
254+
},
255+
},
256+
{
257+
"selector": "edge",
258+
"style": {
259+
"label": "data(label)",
260+
"color": "gray",
261+
"text-halign": "left",
262+
"text-margin-y": "3px",
263+
"text-margin-x": "-2px",
264+
"font-size": "3%",
265+
"width": "1%",
266+
"curve-style": "taxi",
267+
"target-arrow-color": "gray",
268+
"target-arrow-shape": "triangle",
269+
"line-color": "gray",
270+
"arrow-scale": "0.5",
271+
},
272+
},
273+
],
274+
responsive=True,
275+
)
276+
]
277+
)
278+
279+
return app
280+
281+
def render(self, elements, mode):
282+
"""Render graph for lineage query result."""
283+
app = self._get_app(elements)
284+
285+
return app.run_server(mode=mode)
286+
287+
204288
class LineageQueryResult(object):
205289
"""A wrapper around the results of a lineage query."""
206290

@@ -226,29 +310,74 @@ def __init__(
226310

227311
def __str__(self):
228312
"""Define string representation of ``LineageQueryResult``.
229-
313+
230314
Format:
231315
{
232316
'edges':[
233317
{
234-
'source_arn': 'string', 'destination_arn': 'string',
318+
'source_arn': 'string', 'destination_arn': 'string',
235319
'association_type': 'string'
236320
},
237321
...
238322
]
239323
'vertices':[
240324
{
241-
'arn': 'string', 'lineage_entity': 'string',
242-
'lineage_source': 'string',
325+
'arn': 'string', 'lineage_entity': 'string',
326+
'lineage_source': 'string',
243327
'_session': <sagemaker.session.Session object>
244328
},
245329
...
246330
]
247331
}
248-
332+
249333
"""
250334
result_dict = vars(self)
251-
return (str({k: [vars(val) for val in v] for k, v in result_dict.items()}))
335+
return str({k: [vars(val) for val in v] for k, v in result_dict.items()})
336+
337+
def _covert_vertices_to_tuples(self):
338+
"""Convert vertices to tuple format for visualizer."""
339+
verts = []
340+
for vert in self.vertices:
341+
verts.append((vert.arn, vert.lineage_source))
342+
return verts
343+
344+
def _covert_edges_to_tuples(self):
345+
"""Convert edges to tuple format for visualizer."""
346+
edges = []
347+
for edge in self.edges:
348+
edges.append((edge.source_arn, edge.destination_arn, edge.association_type))
349+
return edges
350+
351+
def _get_visualization_elements(self):
352+
"""Get elements for visualization."""
353+
verts = self._covert_vertices_to_tuples()
354+
edges = self._covert_edges_to_tuples()
355+
356+
nodes = [
357+
{
358+
"data": {"id": id, "label": label},
359+
}
360+
for id, label in verts
361+
]
362+
363+
edges = [
364+
{"data": {"source": source, "target": target, "label": label}}
365+
for source, target, label in edges
366+
]
367+
368+
elements = nodes + edges
369+
370+
return elements
371+
372+
def visualize(self):
373+
"""Visualize lineage query result."""
374+
elements = self._get_visualization_elements()
375+
376+
dash_vis = DashVisualizer()
377+
378+
dash_server = dash_vis.render(elements=elements, mode="inline")
379+
380+
return dash_server
252381

253382

254383
class LineageFilter(object):

0 commit comments

Comments
 (0)