Skip to content

fix: Remove duplicate vertex/edge in query lineage #2784

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 12, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 56 additions & 0 deletions src/sagemaker/lineage/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,11 @@
# language governing permissions and limitations under the License.
"""This module contains code to query SageMaker lineage."""
from __future__ import absolute_import

from datetime import datetime
from enum import Enum
from typing import Optional, Union, List, Dict

from sagemaker.lineage._utils import get_resource_name_from_arn


Expand Down Expand Up @@ -65,6 +67,27 @@ def __init__(
self.destination_arn = destination_arn
self.association_type = association_type

def __hash__(self):
    """Return a hash derived from the attributes that identify this ``Edge``.

    The hashed tuple pairs each attribute name with its current value and
    covers the same three attributes that ``__eq__`` compares, so edges that
    compare equal also hash equal and collapse to one entry in a ``set``.

    Returns:
        int: Hash of the (name, value) identity tuple.
    """
    identity = (
        "source_arn",
        self.source_arn,
        "destination_arn",
        self.destination_arn,
        "association_type",
        self.association_type,
    )
    return hash(identity)

def __eq__(self, other):
    """Return whether ``other`` describes the same edge as this one.

    Two edges are equal when their ``association_type``, ``source_arn`` and
    ``destination_arn`` all match; these are the attributes ``__hash__``
    hashes as well.

    Args:
        other: Object to compare against.

    Returns:
        bool: Result of the attribute comparison when ``other`` is an
        instance of this object's class. ``NotImplemented`` otherwise, so
        that Python falls back to its default handling (``==`` evaluates to
        ``False``) instead of raising ``AttributeError`` on unrelated
        objects such as ``None`` or strings.
    """
    # Guard before touching attributes: ``other`` may be any object.
    if not isinstance(other, type(self)):
        return NotImplemented
    return (
        self.association_type == other.association_type
        and self.source_arn == other.source_arn
        and self.destination_arn == other.destination_arn
    )


class Vertex:
"""A vertex for a lineage graph."""
Expand All @@ -82,6 +105,27 @@ def __init__(
self.lineage_source = lineage_source
self._session = sagemaker_session

def __hash__(self):
    """Return a hash derived from the attributes that identify this ``Vertex``.

    The hashed tuple pairs each attribute name with its current value and
    covers the same three attributes that ``__eq__`` compares, so vertices
    that compare equal also hash equal and collapse to one entry in a
    ``set``.

    Returns:
        int: Hash of the (name, value) identity tuple.
    """
    identity = (
        "arn",
        self.arn,
        "lineage_entity",
        self.lineage_entity,
        "lineage_source",
        self.lineage_source,
    )
    return hash(identity)

def __eq__(self, other):
    """Return whether ``other`` describes the same vertex as this one.

    Two vertices are equal when their ``arn``, ``lineage_entity`` and
    ``lineage_source`` all match; these are the attributes ``__hash__``
    hashes as well.

    Args:
        other: Object to compare against.

    Returns:
        bool: Result of the attribute comparison when ``other`` is an
        instance of this object's class. ``NotImplemented`` otherwise, so
        that Python falls back to its default handling (``==`` evaluates to
        ``False``) instead of raising ``AttributeError`` on unrelated
        objects such as ``None`` or strings.
    """
    # Guard before touching attributes: ``other`` may be any object.
    if not isinstance(other, type(self)):
        return NotImplemented
    return (
        self.arn == other.arn
        and self.lineage_entity == other.lineage_entity
        and self.lineage_source == other.lineage_source
    )

def to_lineage_object(self):
"""Convert the ``Vertex`` object to its corresponding Artifact, Action, Context object."""
from sagemaker.lineage.artifact import Artifact, ModelArtifact
Expand Down Expand Up @@ -210,6 +254,18 @@ def _convert_api_response(self, response) -> LineageQueryResult:
converted.edges = [self._get_edge(edge) for edge in response["Edges"]]
converted.vertices = [self._get_vertex(vertex) for vertex in response["Vertices"]]

# Drop duplicate edges and vertices, keeping the first occurrence of each in
# the order the lists were built. ``dict.fromkeys`` uses the ``__hash__`` and
# ``__eq__`` of the elements to detect duplicates and preserves insertion
# order. New lists are built rather than calling ``list.remove`` during
# iteration: removing from a list while looping over it shifts the remaining
# items, so the iterator skips elements and can leave duplicates behind
# (e.g. ``[v1, v2, v1, v3, v3]`` would end up as ``[v2, v1, v3, v3]``).
converted.edges = list(dict.fromkeys(converted.edges))
converted.vertices = list(dict.fromkeys(converted.vertices))

return converted

def _collapse_cross_account_artifacts(self, query_response):
Expand Down
32 changes: 32 additions & 0 deletions tests/unit/sagemaker/lineage/test_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,38 @@ def test_lineage_query(sagemaker_session):
start_arns=["arn:aws:sagemaker:us-west-2:0123456789012:context/mycontext"]
)

assert len(response.edges) == 1
assert response.edges[0].source_arn == "arn1"
assert response.edges[0].destination_arn == "arn2"
assert response.edges[0].association_type == "Produced"
assert len(response.vertices) == 2

assert response.vertices[0].arn == "arn1"
assert response.vertices[0].lineage_source == "Endpoint"
assert response.vertices[0].lineage_entity == "Artifact"
assert response.vertices[1].arn == "arn2"
assert response.vertices[1].lineage_source == "Model"
assert response.vertices[1].lineage_entity == "Context"


def test_lineage_query_duplication(sagemaker_session):
lineage_query = LineageQuery(sagemaker_session)
sagemaker_session.sagemaker_client.query_lineage.return_value = {
"Vertices": [
{"Arn": "arn1", "Type": "Endpoint", "LineageType": "Artifact"},
{"Arn": "arn1", "Type": "Endpoint", "LineageType": "Artifact"},
{"Arn": "arn2", "Type": "Model", "LineageType": "Context"},
],
"Edges": [
{"SourceArn": "arn1", "DestinationArn": "arn2", "AssociationType": "Produced"},
{"SourceArn": "arn1", "DestinationArn": "arn2", "AssociationType": "Produced"},
],
}

response = lineage_query.query(
start_arns=["arn:aws:sagemaker:us-west-2:0123456789012:context/mycontext"]
)

assert len(response.edges) == 1
assert response.edges[0].source_arn == "arn1"
assert response.edges[0].destination_arn == "arn2"
Expand Down