Skip to content

Commit 41d2453

Browse files
author
Yi-Ting Lee
committed
inject graph data to DashVisualizer task
1 parent 7c2b0c3 commit 41d2453

File tree

1 file changed

+46
-16
lines changed

1 file changed

+46
-16
lines changed

src/sagemaker/lineage/query.py

Lines changed: 46 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -204,7 +204,7 @@ def _artifact_to_lineage_object(self):
204204
class DashVisualizer(object):
205205
"""Create object used for visualizing graph using Dash library."""
206206

207-
def __init__(self):
207+
def __init__(self, graph_styles):
208208
"""Init for DashVisualizer."""
209209
# import visualization packages
210210
(
@@ -215,12 +215,7 @@ def __init__(self):
215215
self.Output,
216216
) = self._import_visual_modules()
217217

218-
self.entity_color = {
219-
"TrialComponent": "#f6cf61",
220-
"Context": "#ff9900",
221-
"Action": "#88c396",
222-
"Artifact": "#146eb4",
223-
}
218+
self.graph_styles = graph_styles
224219

225220
def _import_visual_modules(self):
226221
"""Import modules needed for visualization."""
@@ -254,12 +249,19 @@ def _import_visual_modules(self):
254249

255250
return cyto, JupyterDash, html, Input, Output
256251

257-
def _create_legend_component(self, text, color, colorText=""):
252+
def _create_legend_component(self, style):
258253
"""Create legend component div."""
254+
text = style["name"]
255+
symbol = ""
256+
color = "#ffffff"
257+
if style["isShape"] == "False":
258+
color = style["style"]["background-color"]
259+
else:
260+
symbol = style["symbol"]
259261
return self.html.Div(
260262
[
261263
self.html.Div(
262-
colorText,
264+
symbol,
263265
style={
264266
"background-color": color,
265267
"width": "1.5vw",
@@ -282,9 +284,9 @@ def _create_legend_component(self, text, color, colorText=""):
282284
]
283285
)
284286

285-
def _create_entity_selector(self, entity_name, color):
287+
def _create_entity_selector(self, entity_name, style):
286288
"""Create selector for each lineage entity."""
287-
return {"selector": "." + entity_name, "style": {"background-color": color}}
289+
return {"selector": "." + entity_name, "style": style["style"]}
288290

289291
def _get_app(self, elements):
290292
"""Create JupyterDash app for interactivity on Jupyter notebook."""
@@ -337,10 +339,9 @@ def _get_app(self, elements):
337339
"font-family": "verdana",
338340
},
339341
},
340-
{"selector": ".startarn", "style": {"shape": "star"}},
341342
{"selector": ".select", "style": {"border-opacity": "0.7"}},
342343
]
343-
+ [self._create_entity_selector(k, v) for k, v in self.entity_color.items()],
344+
+ [self._create_entity_selector(k, v) for k, v in self.graph_styles.items()],
344345
responsive=True,
345346
),
346347
self.html.Div(
@@ -354,8 +355,7 @@ def _get_app(self, elements):
354355
),
355356
# legend section
356357
self.html.Div(
357-
[self._create_legend_component(k, v) for k, v in self.entity_color.items()]
358-
+ [self._create_legend_component("StartArn", "#ffffff", "★")],
358+
[self._create_legend_component(v) for k, v in self.graph_styles.items()],
359359
style={
360360
"display": "inline-block",
361361
"font-size": "1vw",
@@ -492,8 +492,38 @@ def visualize(self):
492492
"""Visualize lineage query result."""
493493
elements = self._get_visualization_elements()
494494

495+
lineage_graph = {
496+
# nodes can have shape / color
497+
"TrialComponent": {
498+
"name": "Trial Component",
499+
"style": {"background-color": "#f6cf61"},
500+
"isShape": "False",
501+
},
502+
"Context": {
503+
"name": "Context",
504+
"style": {"background-color": "#ff9900"},
505+
"isShape": "False",
506+
},
507+
"Action": {
508+
"name": "Action",
509+
"style": {"background-color": "#88c396"},
510+
"isShape": "False",
511+
},
512+
"Artifact": {
513+
"name": "Artifact",
514+
"style": {"background-color": "#146eb4"},
515+
"isShape": "False",
516+
},
517+
"StartArn": {
518+
"name": "StartArn",
519+
"style": {"shape": "star"},
520+
"isShape": "True",
521+
"symbol": "★", # shape symbol for legend
522+
},
523+
}
524+
495525
# initialize DashVisualizer instance to render graph & interactive components
496-
dash_vis = DashVisualizer()
526+
dash_vis = DashVisualizer(lineage_graph)
497527

498528
dash_server = dash_vis.render(elements=elements, mode="inline")
499529

0 commit comments

Comments
 (0)