Skip to content

Commit 79b8366

Browse files
committed
remove duplicate vertex/edge in query lineage
1 parent d970bfb commit 79b8366

File tree

2 files changed

+76
-0
lines changed

2 files changed

+76
-0
lines changed

src/sagemaker/lineage/query.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,22 @@ def __init__(
6565
self.destination_arn = destination_arn
6666
self.association_type = association_type
6767

68+
def __hash__(self):
    """Return a hash derived from the source ARN, destination ARN and association type.

    The hashed tuple interleaves each attribute name with its value, so two
    objects whose three attributes are equal produce the same hash and can be
    de-duplicated through a ``set``.
    """
    identity = (
        "source_arn",
        self.source_arn,
        "destination_arn",
        self.destination_arn,
        "association_type",
        self.association_type,
    )
    return hash(identity)
76+
77+
def __eq__(self, other):
    """Compare two objects by association type, source ARN and destination ARN.

    Args:
        other: The object to compare against.

    Returns:
        bool: ``True`` when all three attributes match, ``False`` otherwise.
        ``NotImplemented`` is returned when ``other`` is not an instance of
        this object's class, so that Python falls back to its default
        comparison instead of raising ``AttributeError`` (for example on
        ``obj == None`` or ``obj in ["some string"]``).
    """
    # Guard against foreign types that lack the compared attributes.
    if not isinstance(other, type(self)):
        return NotImplemented
    return (
        self.association_type == other.association_type
        and self.source_arn == other.source_arn
        and self.destination_arn == other.destination_arn
    )
83+
6884

6985
class Vertex:
7086
"""A vertex for a lineage graph."""
@@ -82,6 +98,22 @@ def __init__(
8298
self.lineage_source = lineage_source
8399
self._session = sagemaker_session
84100

101+
def __hash__(self):
    """Return a hash derived from the ARN, lineage entity and lineage source.

    The hashed tuple interleaves each attribute name with its value, so two
    vertices whose three attributes are equal produce the same hash and can be
    de-duplicated through a ``set``.
    """
    identity = (
        "arn",
        self.arn,
        "lineage_entity",
        self.lineage_entity,
        "lineage_source",
        self.lineage_source,
    )
    return hash(identity)
109+
110+
def __eq__(self, other):
    """Compare two vertices by ARN, lineage entity and lineage source.

    Args:
        other: The object to compare against.

    Returns:
        bool: ``True`` when all three attributes match, ``False`` otherwise.
        ``NotImplemented`` is returned when ``other`` is not an instance of
        this object's class, so that Python falls back to its default
        comparison instead of raising ``AttributeError`` (for example on
        ``vertex == None`` or ``vertex in ["some string"]``).
    """
    # Guard against foreign types that lack the compared attributes.
    if not isinstance(other, type(self)):
        return NotImplemented
    return (
        self.arn == other.arn
        and self.lineage_entity == other.lineage_entity
        and self.lineage_source == other.lineage_source
    )
116+
85117
def to_lineage_object(self):
86118
"""Convert the ``Vertex`` object to its corresponding ``Artifact`` or ``Context`` object."""
87119
from sagemaker.lineage.artifact import Artifact, ModelArtifact
@@ -206,6 +238,18 @@ def _convert_api_response(self, response) -> LineageQueryResult:
206238
converted.edges = [self._get_edge(edge) for edge in response["Edges"]]
207239
converted.vertices = [self._get_vertex(vertex) for vertex in response["Vertices"]]
208240

241+
# Drop repeated edges, keeping the first occurrence of each in its original
# position. A fresh list is built rather than calling ``remove()`` on the list
# being iterated: removing during iteration shifts the remaining items, so the
# element after every removal is skipped and runs of three or more identical
# entries would leave a duplicate behind.
seen_edges = set()
unique_edges = []
for edge in converted.edges:
    if edge not in seen_edges:
        seen_edges.add(edge)
        unique_edges.append(edge)
converted.edges = unique_edges

# Same order-preserving de-duplication for vertices.
seen_vertices = set()
unique_vertices = []
for vertex in converted.vertices:
    if vertex not in seen_vertices:
        seen_vertices.add(vertex)
        unique_vertices.append(vertex)
converted.vertices = unique_vertices
252+
209253
return converted
210254

211255
def query(

tests/unit/sagemaker/lineage/test_query.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,38 @@ def test_lineage_query(sagemaker_session):
3131
start_arns=["arn:aws:sagemaker:us-west-2:0123456789012:context/mycontext"]
3232
)
3333

34+
assert len(response.edges) == 1
35+
assert response.edges[0].source_arn == "arn1"
36+
assert response.edges[0].destination_arn == "arn2"
37+
assert response.edges[0].association_type == "Produced"
38+
assert len(response.vertices) == 2
39+
40+
assert response.vertices[0].arn == "arn1"
41+
assert response.vertices[0].lineage_source == "Endpoint"
42+
assert response.vertices[0].lineage_entity == "Artifact"
43+
assert response.vertices[1].arn == "arn2"
44+
assert response.vertices[1].lineage_source == "Model"
45+
assert response.vertices[1].lineage_entity == "Context"
46+
47+
48+
def test_lineage_query_duplication(sagemaker_session):
49+
lineage_query = LineageQuery(sagemaker_session)
50+
sagemaker_session.sagemaker_client.query_lineage.return_value = {
51+
"Vertices": [
52+
{"Arn": "arn1", "Type": "Endpoint", "LineageType": "Artifact"},
53+
{"Arn": "arn1", "Type": "Endpoint", "LineageType": "Artifact"},
54+
{"Arn": "arn2", "Type": "Model", "LineageType": "Context"},
55+
],
56+
"Edges": [
57+
{"SourceArn": "arn1", "DestinationArn": "arn2", "AssociationType": "Produced"},
58+
{"SourceArn": "arn1", "DestinationArn": "arn2", "AssociationType": "Produced"},
59+
],
60+
}
61+
62+
response = lineage_query.query(
63+
start_arns=["arn:aws:sagemaker:us-west-2:0123456789012:context/mycontext"]
64+
)
65+
3466
assert len(response.edges) == 1
3567
assert response.edges[0].source_arn == "arn1"
3668
assert response.edges[0].destination_arn == "arn2"

0 commit comments

Comments
 (0)