Skip to content

Commit 7359b3d

Browse files
author
Yi-Ting Lee
committed
feature: query lineage visualizer for general case
edge.association_type added style changes of graph
1 parent e0c59c2 commit 7359b3d

File tree

2 files changed

+216
-13
lines changed

2 files changed

+216
-13
lines changed

src/sagemaker/lineage/query.py

Lines changed: 106 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -97,12 +97,12 @@ def __str__(self):
9797
9898
Format:
9999
{
100-
'source_arn': 'string', 'destination_arn': 'string',
100+
'source_arn': 'string', 'destination_arn': 'string',
101101
'association_type': 'string'
102102
}
103-
103+
104104
"""
105-
return (str(self.__dict__))
105+
return str(self.__dict__)
106106

107107

108108
class Vertex:
@@ -147,13 +147,13 @@ def __str__(self):
147147
148148
Format:
149149
{
150-
'arn': 'string', 'lineage_entity': 'string',
151-
'lineage_source': 'string',
150+
'arn': 'string', 'lineage_entity': 'string',
151+
'lineage_source': 'string',
152152
'_session': <sagemaker.session.Session object>
153153
}
154-
154+
155155
"""
156-
return (str(self.__dict__))
156+
return str(self.__dict__)
157157

158158
def to_lineage_object(self):
159159
"""Convert the ``Vertex`` object to its corresponding lineage object.
@@ -226,29 +226,122 @@ def __init__(
226226

227227
def __str__(self):
228228
"""Define string representation of ``LineageQueryResult``.
229-
229+
230230
Format:
231231
{
232232
'edges':[
233233
{
234-
'source_arn': 'string', 'destination_arn': 'string',
234+
'source_arn': 'string', 'destination_arn': 'string',
235235
'association_type': 'string'
236236
},
237237
...
238238
]
239239
'vertices':[
240240
{
241-
'arn': 'string', 'lineage_entity': 'string',
242-
'lineage_source': 'string',
241+
'arn': 'string', 'lineage_entity': 'string',
242+
'lineage_source': 'string',
243243
'_session': <sagemaker.session.Session object>
244244
},
245245
...
246246
]
247247
}
248-
248+
249249
"""
250250
result_dict = vars(self)
251-
return (str({k: [vars(val) for val in v] for k, v in result_dict.items()}))
251+
return str({k: [vars(val) for val in v] for k, v in result_dict.items()})
252+
253+
def _import_visual_modules(self):
    """Load the third-party packages used by the visualizer.

    The imports happen inside this method, so the packages are only
    required when it is called.

    Returns:
        tuple: ``(dash_cytoscape, JupyterDash, html)`` where ``JupyterDash``
            comes from ``jupyter_dash`` and ``html`` comes from ``dash``.
    """
    import dash_cytoscape
    from jupyter_dash import JupyterDash
    from dash import html

    visual_modules = (dash_cytoscape, JupyterDash, html)
    return visual_modules
262+
263+
def _get_verts(self):
264+
"""Convert vertices to tuple format for visualizer"""
265+
verts = []
266+
for vert in self.vertices:
267+
verts.append((vert.arn, vert.lineage_source))
268+
return verts
269+
270+
def _get_edges(self):
271+
"""Convert edges to tuple format for visualizer"""
272+
edges = []
273+
for edge in self.edges:
274+
edges.append((edge.source_arn, edge.destination_arn, edge.association_type))
275+
return edges
276+
277+
def visualize(self):
    """Render the lineage query result as an inline graph.

    Turns the vertex tuples into Cytoscape nodes and the edge tuples into
    Cytoscape edges, places them in a single ``Cytoscape`` component using
    the "klay" layout, and runs the ``JupyterDash`` app in ``"inline"`` mode.

    Returns:
        Whatever ``app.run_server(mode="inline")`` returns.
    """
    cyto, JupyterDash, html = self._import_visual_modules()

    # The "klay" (hierarchical) layout comes from the extra layouts.
    cyto.load_extra_layouts()
    app = JupyterDash(__name__)

    vertex_tuples = self._get_verts()
    edge_tuples = self._get_edges()

    node_elements = [
        {"data": {"id": node_id, "label": node_label}}
        for node_id, node_label in vertex_tuples
    ]
    edge_elements = [
        {"data": {"source": src, "target": dst, "label": edge_label}}
        for src, dst, edge_label in edge_tuples
    ]

    node_style = {
        "selector": "node",
        "style": {
            "label": "data(label)",
            "font-size": "3.5vw",
            "height": "10vw",
            "width": "10vw",
        },
    }
    edge_style = {
        "selector": "edge",
        "style": {
            "label": "data(label)",
            "color": "gray",
            "text-halign": "left",
            "text-margin-y": "3px",
            "text-margin-x": "-2px",
            "font-size": "3%",
            "width": "1%",
            "curve-style": "taxi",
            "target-arrow-color": "gray",
            "target-arrow-shape": "triangle",
            "line-color": "gray",
            "arrow-scale": "0.5",
        },
    }

    graph = cyto.Cytoscape(
        id="cytoscape-layout-1",
        elements=node_elements + edge_elements,
        style={"width": "100%", "height": "350px"},
        layout={"name": "klay"},
        stylesheet=[node_style, edge_style],
        responsive=True,
    )
    app.layout = html.Div([graph])

    return app.run_server(mode="inline")
252345

253346

254347
class LineageFilter(object):

tests/data/_repack_model.py

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Repack model script for training jobs to inject entry points"""
14+
from __future__ import absolute_import
15+
16+
import argparse
17+
import os
18+
import shutil
19+
import tarfile
20+
import tempfile
21+
22+
# Repack Model
23+
# The following script is run via a training job which takes an existing model and a custom
24+
# entry point script as arguments. The script creates a new model archive with the custom
25+
# entry point in the "code" directory along with the existing model. Subsequently, when the model
26+
# is unpacked for inference, the custom entry point will be used.
27+
# Reference: https://docs.aws.amazon.com/sagemaker/latest/dg/amazon-sagemaker-toolkits.html
28+
29+
# distutils.dir_util.copy_tree works way better than the half-baked
30+
# shutil.copytree which bombs on previously existing target dirs...
31+
# alas ... https://bugs.python.org/issue10948
32+
# we'll go ahead and use the copy_tree function anyways because this
33+
# repacking is some short-lived hackery, right??
34+
from distutils.dir_util import copy_tree
35+
36+
37+
def repack(inference_script, model_archive, dependencies=None, source_dir=None):  # pragma: no cover
    """Repack custom dependencies and code into an existing model TAR archive.

    The archive is read from ``/opt/ml/input/data/training``, unpacked into a
    temporary "src" directory, combined with code from ``/opt/ml/code`` under
    "src/code", and the result is copied to ``/opt/ml/model``.

    Args:
        inference_script (str): The path to the custom entry point, relative
            to ``/opt/ml/code``. Only used when ``source_dir`` is falsy.
        model_archive (str): The name or path (e.g. s3 uri) of the model TAR
            archive. Only the part after the last "/" is used to locate the file.
        dependencies (str): A whitespace-delimited string of paths to custom
            dependencies, relative to ``/opt/ml/code``.
        source_dir (str): The path to a custom source directory. When truthy,
            all of ``/opt/ml/code`` replaces the archive's "code" directory.
    """
    code_root = "/opt/ml/code"

    # the data directory contains a model archive generated by a previous training job
    data_directory = "/opt/ml/input/data/training"
    model_path = os.path.join(data_directory, model_archive.split("/")[-1])

    # create a temporary directory
    with tempfile.TemporaryDirectory() as tmp:
        local_path = os.path.join(tmp, "local.tar.gz")
        # copy the previous training job's model archive to the temporary directory
        shutil.copy2(model_path, local_path)
        src_dir = os.path.join(tmp, "src")
        # create the "code" directory which will contain the inference script
        code_dir = os.path.join(src_dir, "code")
        os.makedirs(code_dir)
        # extract the contents of the previous training job's model archive to the "src"
        # directory of this training job
        # NOTE(review): extractall() is used without member filtering, so the archive
        # is trusted not to contain absolute or "../" paths - confirm that holds for
        # every archive this job can receive.
        with tarfile.open(name=local_path, mode="r:gz") as tf:
            tf.extractall(path=src_dir)

        if source_dir:
            # copy /opt/ml/code to code/
            if os.path.exists(code_dir):
                shutil.rmtree(code_dir)
            shutil.copytree(code_root, code_dir)
        else:
            # copy the custom inference script to code/
            entry_point = os.path.join(code_root, inference_script)
            shutil.copy2(entry_point, os.path.join(code_dir, inference_script))

        # copy any dependencies to code/lib/
        if dependencies:
            lib_dir = os.path.join(code_dir, "lib")
            # split() without an argument drops the empty tokens that repeated, leading
            # or trailing whitespace would otherwise produce; an empty token resolves to
            # /opt/ml/code itself and would wrongly take the "directory" branch below.
            for dependency in dependencies.split():
                actual_dependency_path = os.path.join(code_root, dependency)
                if not os.path.exists(lib_dir):
                    os.mkdir(lib_dir)
                if os.path.isfile(actual_dependency_path):
                    shutil.copy2(actual_dependency_path, lib_dir)
                else:
                    if os.path.exists(lib_dir):
                        shutil.rmtree(lib_dir)
                    # a directory is in the dependencies. we have to copy
                    # all of /opt/ml/code into the lib dir because the original directory
                    # was flattened by the SDK training job upload..
                    shutil.copytree(code_root, lib_dir)
                    # everything has been copied, so the remaining entries are skipped
                    break

        # copy the "src" dir, which includes the previous training job's model and the
        # custom inference script, to the output of this training job
        copy_tree(src_dir, "/opt/ml/model")
96+
97+
98+
if __name__ == "__main__":  # pragma: no cover
    # Command-line entry point: read the repack options and run the repack.
    cli = argparse.ArgumentParser()
    cli.add_argument("--inference_script", type=str, default="inference.py")
    cli.add_argument("--dependencies", type=str, default=None)
    cli.add_argument("--source_dir", type=str, default=None)
    cli.add_argument("--model_archive", type=str, default="model.tar.gz")
    # parse_known_args() tolerates options this script does not declare;
    # they are collected separately and not used.
    known_args, _unrecognised = cli.parse_known_args()
    repack(
        inference_script=known_args.inference_script,
        model_archive=known_args.model_archive,
        dependencies=known_args.dependencies,
        source_dir=known_args.source_dir,
    )

0 commit comments

Comments
 (0)