Skip to content

feature: Adds Lineage queries in artifact, context and trial components #2838

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 18, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions src/sagemaker/lineage/action.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,11 @@
from typing import Optional, Iterator, List
from datetime import datetime

from sagemaker import Session
from sagemaker.session import Session
from sagemaker.apiutils import _base_types
from sagemaker.lineage import _api_types, _utils
from sagemaker.lineage._api_types import ActionSource, ActionSummary
from sagemaker.lineage.artifact import Artifact
from sagemaker.lineage.context import Context

from sagemaker.lineage.query import (
LineageQuery,
Expand Down Expand Up @@ -126,7 +125,7 @@ def delete(self, disassociate: bool = False):
self._invoke_api(self._boto_delete_method, self._boto_delete_members)

@classmethod
def load(cls, action_name: str, sagemaker_session: Session = None) -> "Action":
def load(cls, action_name: str, sagemaker_session=None) -> "Action":
"""Load an existing action and return an ``Action`` object representing it.

Args:
Expand Down Expand Up @@ -324,7 +323,7 @@ def model_package(self):

def endpoints(
self, direction: LineageQueryDirectionEnum = LineageQueryDirectionEnum.DESCENDANTS
) -> List[Context]:
) -> List:
"""Use a lineage query to retrieve downstream endpoint contexts that use this action.

Args:
Expand Down
133 changes: 129 additions & 4 deletions src/sagemaker/lineage/artifact.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,10 +143,10 @@ def load(cls, artifact_arn: str, sagemaker_session=None) -> "Artifact":
return artifact

def downstream_trials(self, sagemaker_session=None) -> list:
"""Retrieve all trial runs which that use this artifact.
"""Use the lineage API to retrieve all downstream trials that use this artifact.

Args:
sagemaker_session (obj): Sagemaker Sesssion to use. If not provided a default session
sagemaker_session (obj): Sagemaker Session to use. If not provided a default session
will be created.

Returns:
Expand All @@ -159,6 +159,54 @@ def downstream_trials(self, sagemaker_session=None) -> list:
)
trial_component_arns: list = list(map(lambda x: x.destination_arn, outgoing_associations))

return self._get_trial_from_trial_component(trial_component_arns)

def downstream_trials_v2(self) -> list:
"""Use a lineage query to retrieve all downstream trials that use this artifact.

Returns:
[Trial]: A list of SageMaker `Trial` objects.
"""
return self._trials(direction=LineageQueryDirectionEnum.DESCENDANTS)

def upstream_trials(self) -> List:
"""Use the lineage query to retrieve all upstream trials that use this artifact.

Returns:
[Trial]: A list of SageMaker `Trial` objects.
"""
return self._trials(direction=LineageQueryDirectionEnum.ASCENDANTS)

def _trials(
self, direction: LineageQueryDirectionEnum = LineageQueryDirectionEnum.BOTH
) -> List:
"""Use the lineage query to retrieve all trials that use this artifact.

Args:
direction (LineageQueryDirectionEnum, optional): The query direction.

Returns:
[Trial]: A list of SageMaker `Trial` objects.
"""
query_filter = LineageFilter(entities=[LineageEntityEnum.TRIAL_COMPONENT])
query_result = LineageQuery(self.sagemaker_session).query(
start_arns=[self.artifact_arn],
query_filter=query_filter,
direction=direction,
include_edges=False,
)
trial_component_arns: list = list(map(lambda x: x.arn, query_result.vertices))
return self._get_trial_from_trial_component(trial_component_arns)

def _get_trial_from_trial_component(self, trial_component_arns: list) -> List:
"""Retrieve all upstream trial runs which that use the trial component arns.

Args:
trial_component_arns (list): list of trial component arns

Returns:
[Trial]: A list of SageMaker `Trial` objects.
"""
if not trial_component_arns:
# no outgoing associations for this artifact
return []
Expand All @@ -170,7 +218,7 @@ def downstream_trials(self, sagemaker_session=None) -> list:
num_search_batches = math.ceil(len(trial_component_arns) % max_search_by_arn)
trial_components: list = []

sagemaker_session = sagemaker_session or _utils.default_session()
sagemaker_session = self.sagemaker_session or _utils.default_session()
sagemaker_client = sagemaker_session.sagemaker_client

for i in range(num_search_batches):
Expand Down Expand Up @@ -335,6 +383,17 @@ def list(
sagemaker_session=sagemaker_session,
)

def s3_uri_artifacts(self, s3_uri: str) -> dict:
"""Retrieve a list of artifacts that use provided s3 uri.

Args:
s3_uri (str): A S3 URI.

Returns:
A list of ``Artifacts``
"""
return self.sagemaker_session.sagemaker_client.list_artifacts(SourceUri=s3_uri)


class ModelArtifact(Artifact):
"""A SageMaker lineage artifact representing a model.
Expand All @@ -349,7 +408,7 @@ def endpoints(self) -> list:
"""Get association summaries for endpoints deployed with this model.

Returns:
[AssociationSummary]: A list of associations repesenting the endpoints using the model.
[AssociationSummary]: A list of associations representing the endpoints using the model.
"""
endpoint_development_actions: Iterator = Association.list(
source_arn=self.artifact_arn,
Expand Down Expand Up @@ -522,3 +581,69 @@ def endpoint_contexts(
for vertex in query_result.vertices:
endpoint_contexts.append(vertex.to_lineage_object())
return endpoint_contexts

def upstream_datasets(self) -> List[Artifact]:
"""Use the lineage query to retrieve upstream artifacts that use this dataset artifact.

Returns:
list of Artifacts: Artifacts representing an dataset.
"""
return self._datasets(direction=LineageQueryDirectionEnum.ASCENDANTS)

def downstream_datasets(self) -> List[Artifact]:
"""Use the lineage query to retrieve downstream artifacts that use this dataset.

Returns:
list of Artifacts: Artifacts representing an dataset.
"""
return self._datasets(direction=LineageQueryDirectionEnum.DESCENDANTS)

def _datasets(
self, direction: LineageQueryDirectionEnum = LineageQueryDirectionEnum.BOTH
) -> List[Artifact]:
"""Use the lineage query to retrieve all artifacts that use this dataset.

Args:
direction (LineageQueryDirectionEnum, optional): The query direction.

Returns:
list of Artifacts: Artifacts representing an dataset.
"""
query_filter = LineageFilter(
entities=[LineageEntityEnum.ARTIFACT], sources=[LineageSourceEnum.DATASET]
)
query_result = LineageQuery(self.sagemaker_session).query(
start_arns=[self.artifact_arn],
query_filter=query_filter,
direction=direction,
include_edges=False,
)
return [vertex.to_lineage_object() for vertex in query_result.vertices]


class ImageArtifact(Artifact):
"""A SageMaker lineage artifact representing an image.

Common model specific lineage traversals to discover how the image is connected
to other entities.
"""

def datasets(self, direction: LineageQueryDirectionEnum) -> List[Artifact]:
"""Use the lineage query to retrieve datasets that use this image artifact.

Args:
direction (LineageQueryDirectionEnum): The query direction.

Returns:
list of Artifacts: Artifacts representing a dataset.
"""
query_filter = LineageFilter(
entities=[LineageEntityEnum.ARTIFACT], sources=[LineageSourceEnum.DATASET]
)
query_result = LineageQuery(self.sagemaker_session).query(
start_arns=[self.artifact_arn],
query_filter=query_filter,
direction=direction,
include_edges=False,
)
return [vertex.to_lineage_object() for vertex in query_result.vertices]
96 changes: 93 additions & 3 deletions src/sagemaker/lineage/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@
LineageQueryDirectionEnum,
)
from sagemaker.lineage.artifact import Artifact
from sagemaker.lineage.action import Action
from sagemaker.lineage.lineage_trial_component import LineageTrialComponent


class Context(_base_types.Record):
Expand Down Expand Up @@ -256,12 +258,30 @@ def list(
sagemaker_session=sagemaker_session,
)

def actions(self, direction: LineageQueryDirectionEnum) -> List[Action]:
"""Use the lineage query to retrieve actions that use this context.

Args:
direction (LineageQueryDirectionEnum): The query direction.

Returns:
list of Actions: Actions.
"""
query_filter = LineageFilter(entities=[LineageEntityEnum.ACTION])
query_result = LineageQuery(self.sagemaker_session).query(
start_arns=[self.context_arn],
query_filter=query_filter,
direction=direction,
include_edges=False,
)
return [vertex.to_lineage_object() for vertex in query_result.vertices]


class EndpointContext(Context):
"""An Amazon SageMaker endpoint context, which is part of a SageMaker lineage."""

def models(self) -> List[association.Association]:
"""Get all models deployed by all endpoint versions of the endpoint.
"""Use Lineage API to get all models deployed by this endpoint.

Returns:
list of Associations: Associations that destination represents an endpoint's model.
Expand All @@ -286,7 +306,7 @@ def models(self) -> List[association.Association]:
def models_v2(
self, direction: LineageQueryDirectionEnum = LineageQueryDirectionEnum.DESCENDANTS
) -> List[Artifact]:
"""Get artifacts representing models from the context lineage by querying lineage data.
"""Use the lineage query to retrieve downstream model artifacts that use this endpoint.

Args:
direction (LineageQueryDirectionEnum, optional): The query direction.
Expand Down Expand Up @@ -335,7 +355,7 @@ def models_v2(
def dataset_artifacts(
self, direction: LineageQueryDirectionEnum = LineageQueryDirectionEnum.ASCENDANTS
) -> List[Artifact]:
"""Get artifacts representing datasets from the endpoint's lineage.
"""Use the lineage query to retrieve datasets that use this endpoint.

Args:
direction (LineageQueryDirectionEnum, optional): The query direction.
Expand All @@ -360,6 +380,9 @@ def training_job_arns(
) -> List[str]:
"""Get ARNs for all training jobs that appear in the endpoint's lineage.

Args:
direction (LineageQueryDirectionEnum, optional): The query direction.

Returns:
list of str: Training job ARNs.
"""
Expand All @@ -382,11 +405,78 @@ def training_job_arns(
training_job_arns.append(trial_component["Source"]["SourceArn"])
return training_job_arns

def processing_jobs(
self, direction: LineageQueryDirectionEnum = LineageQueryDirectionEnum.ASCENDANTS
) -> List[LineageTrialComponent]:
"""Use the lineage query to retrieve processing jobs that use this endpoint.

Args:
direction (LineageQueryDirectionEnum, optional): The query direction.

Returns:
list of LineageTrialComponent: Lineage trial component that represent Processing jobs.
"""
query_filter = LineageFilter(
entities=[LineageEntityEnum.TRIAL_COMPONENT], sources=[LineageSourceEnum.PROCESSING_JOB]
)
query_result = LineageQuery(self.sagemaker_session).query(
start_arns=[self.context_arn],
query_filter=query_filter,
direction=direction,
include_edges=False,
)
return [vertex.to_lineage_object() for vertex in query_result.vertices]

def transform_jobs(
self, direction: LineageQueryDirectionEnum = LineageQueryDirectionEnum.ASCENDANTS
) -> List[LineageTrialComponent]:
"""Use the lineage query to retrieve transform jobs that use this endpoint.

Args:
direction (LineageQueryDirectionEnum, optional): The query direction.

Returns:
list of LineageTrialComponent: Lineage trial component that represent Transform jobs.
"""
query_filter = LineageFilter(
entities=[LineageEntityEnum.TRIAL_COMPONENT], sources=[LineageSourceEnum.TRANSFORM_JOB]
)
query_result = LineageQuery(self.sagemaker_session).query(
start_arns=[self.context_arn],
query_filter=query_filter,
direction=direction,
include_edges=False,
)
return [vertex.to_lineage_object() for vertex in query_result.vertices]

def trial_components(
self, direction: LineageQueryDirectionEnum = LineageQueryDirectionEnum.ASCENDANTS
) -> List[LineageTrialComponent]:
"""Use the lineage query to retrieve trial components that use this endpoint.

Args:
direction (LineageQueryDirectionEnum, optional): The query direction.

Returns:
list of LineageTrialComponent: Lineage trial component.
"""
query_filter = LineageFilter(entities=[LineageEntityEnum.TRIAL_COMPONENT])
query_result = LineageQuery(self.sagemaker_session).query(
start_arns=[self.context_arn],
query_filter=query_filter,
direction=direction,
include_edges=False,
)
return [vertex.to_lineage_object() for vertex in query_result.vertices]

def pipeline_execution_arn(
self, direction: LineageQueryDirectionEnum = LineageQueryDirectionEnum.ASCENDANTS
) -> str:
"""Get the ARN for the pipeline execution associated with this endpoint (if any).

Args:
direction (LineageQueryDirectionEnum, optional): The query direction.

Returns:
str: A pipeline execution ARN.
"""
Expand Down
Loading