Skip to content

Commit 5609e42

Browse files
author
Yi-Ting Lee
committed
test_get_visualization_elements added
1 parent 3bab76a commit 5609e42

File tree

3 files changed

+33
-3
lines changed

3 files changed

+33
-3
lines changed

src/sagemaker/lineage/query.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -364,7 +364,7 @@ def _get_visualization_elements(self):
364364
elements = {"nodes": verts, "edges": edges}
365365
return elements
366366

367-
def visualize(self):
367+
def visualize(self, path="pyvisExample.html"):
368368
"""Visualize lineage query result."""
369369
lineage_graph = {
370370
# nodes can have shape / color
@@ -398,7 +398,7 @@ def visualize(self):
398398

399399
pyvis_vis = PyvisVisualizer(lineage_graph)
400400
elements = self._get_visualization_elements()
401-
return pyvis_vis.render(elements=elements)
401+
return pyvis_vis.render(elements=elements, path=path)
402402

403403

404404
class LineageFilter(object):

tests/integ/sagemaker/lineage/helpers.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@
1616
import uuid
1717
from datetime import datetime
1818
import time
19-
19+
import boto3
20+
from botocore.config import Config
2021

2122
def name():
2223
return "lineage-integ-{}-{}".format(
@@ -78,3 +79,4 @@ def visit(arn, visited: set):
7879

7980
ret = []
8081
return visit(start_arn, set())
82+

tests/unit/sagemaker/lineage/test_query.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -524,3 +524,31 @@ def test_vertex_to_object_unconvertable(sagemaker_session):
524524

525525
with pytest.raises(ValueError):
526526
vertex.to_lineage_object()
527+
528+
529+
def test_get_visualization_elements(sagemaker_session):
530+
lineage_query = LineageQuery(sagemaker_session)
531+
sagemaker_session.sagemaker_client.query_lineage.return_value = {
532+
"Vertices": [
533+
{"Arn": "arn1", "Type": "Endpoint", "LineageType": "Artifact"},
534+
{"Arn": "arn2", "Type": "Model", "LineageType": "Context"},
535+
],
536+
"Edges": [{"SourceArn": "arn1", "DestinationArn": "arn2", "AssociationType": "Produced"}],
537+
}
538+
539+
query_response = lineage_query.query(
540+
start_arns=["arn:aws:sagemaker:us-west-2:0123456789012:context/mycontext"]
541+
)
542+
543+
print(query_response)
544+
545+
elements = query_response._get_visualization_elements()
546+
547+
print(elements)
548+
549+
assert elements["nodes"][0] == ("arn1", "Endpoint", "Artifact", False)
550+
assert elements["nodes"][1] == ("arn2", "Model", "Context", False)
551+
assert elements["edges"][0] == ("arn1", "arn2", "Produced")
552+
553+
554+

0 commit comments

Comments
 (0)