Skip to content

Commit fa81fa0

Browse files
committed
feature: Add support for SageMaker lineage queries in artifact, context and trial component
1 parent 554d735 commit fa81fa0

File tree

14 files changed

+809
-17
lines changed

14 files changed

+809
-17
lines changed

src/sagemaker/lineage/action.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from typing import Optional, Iterator
1717
from datetime import datetime
1818

19-
from sagemaker import Session
19+
from sagemaker.session import Session
2020
from sagemaker.apiutils import _base_types
2121
from sagemaker.lineage import _api_types, _utils
2222
from sagemaker.lineage._api_types import ActionSource, ActionSummary
@@ -116,7 +116,7 @@ def delete(self, disassociate: bool = False):
116116
self._invoke_api(self._boto_delete_method, self._boto_delete_members)
117117

118118
@classmethod
119-
def load(cls, action_name: str, sagemaker_session: Session = None) -> "Action":
119+
def load(cls, action_name: str, sagemaker_session=None) -> "Action":
120120
"""Load an existing action and return an ``Action`` object representing it.
121121
122122
Args:

src/sagemaker/lineage/artifact.py

Lines changed: 137 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,7 @@ def downstream_trials(self, sagemaker_session=None) -> list:
146146
"""Retrieve all trial runs which that use this artifact.
147147
148148
Args:
149-
sagemaker_session (obj): Sagemaker Sesssion to use. If not provided a default session
149+
sagemaker_session (obj): Sagemaker Session to use. If not provided a default session
150150
will be created.
151151
152152
Returns:
@@ -159,6 +159,57 @@ def downstream_trials(self, sagemaker_session=None) -> list:
159159
)
160160
trial_component_arns: list = list(map(lambda x: x.destination_arn, outgoing_associations))
161161

162+
return self._get_trial_from_trial_component(trial_component_arns)
163+
164+
def downstream_trials_v2(
    self, direction: LineageQueryDirectionEnum = LineageQueryDirectionEnum.DESCENDANTS
) -> list:
    """Retrieve all downstream trial runs that use this artifact by using a lineage query.

    Args:
        direction (LineageQueryDirectionEnum, optional): The query direction.
            Defaults to ``LineageQueryDirectionEnum.DESCENDANTS``.

    Returns:
        [Trial]: A list of SageMaker `Trial` objects.
    """
    # Restrict the query to trial component vertices; edges are not requested.
    trial_component_filter = LineageFilter(entities=[LineageEntityEnum.TRIAL_COMPONENT])
    lineage = LineageQuery(self.sagemaker_session).query(
        start_arns=[self.artifact_arn],
        query_filter=trial_component_filter,
        direction=direction,
        include_edges=False,
    )
    component_arns: list = [vertex.arn for vertex in lineage.vertices]

    # The shared helper resolves trial component ARNs to their trials.
    return self._get_trial_from_trial_component(component_arns)
185+
186+
def get_upstream_trials(
    self, direction: LineageQueryDirectionEnum = LineageQueryDirectionEnum.ASCENDANTS
) -> list:
    """Retrieve all upstream trial runs that use this artifact by using a lineage query.

    Args:
        direction (LineageQueryDirectionEnum, optional): The query direction.
            Defaults to ``LineageQueryDirectionEnum.ASCENDANTS``.

    Returns:
        [Trial]: A list of SageMaker `Trial` objects.
    """
    # Restrict the query to trial component vertices; edges are not requested.
    query_filter = LineageFilter(entities=[LineageEntityEnum.TRIAL_COMPONENT])
    query_result = LineageQuery(self.sagemaker_session).query(
        start_arns=[self.artifact_arn],
        query_filter=query_filter,
        direction=direction,
        include_edges=False,
    )
    trial_component_arns: list = [vertex.arn for vertex in query_result.vertices]

    # The result comes from the same helper that ``downstream_trials_v2`` uses, so the
    # return annotation is ``list`` (of trials) rather than a list of strings.
    return self._get_trial_from_trial_component(trial_component_arns)
206+
207+
def _get_trial_from_trial_component(self, trial_component_arns: list):
208+
"""Retrieve all upstream trial runs which that use the trial component arns.
209+
210+
Returns:
211+
[Trial]: A list of SageMaker `Trial` objects.
212+
"""
162213
if not trial_component_arns:
163214
# no outgoing associations for this artifact
164215
return []
@@ -170,7 +221,7 @@ def downstream_trials(self, sagemaker_session=None) -> list:
170221
num_search_batches = math.ceil(len(trial_component_arns) % max_search_by_arn)
171222
trial_components: list = []
172223

173-
sagemaker_session = sagemaker_session or _utils.default_session()
224+
sagemaker_session = self.sagemaker_session or _utils.default_session()
174225
sagemaker_client = sagemaker_session.sagemaker_client
175226

176227
for i in range(num_search_batches):
@@ -335,6 +386,17 @@ def list(
335386
sagemaker_session=sagemaker_session,
336387
)
337388

389+
def get_artifacts(self, s3_uri: str):
    """List the artifacts whose source URI is the provided S3 URI.

    Args:
        s3_uri (str): A S3 URI.

    Returns:
        The result of the SageMaker client's ``list_artifacts`` call made with
        ``SourceUri=s3_uri``, returned as-is (it is not converted into ``Artifact``
        objects here).
    """
    sagemaker_client = self.sagemaker_session.sagemaker_client
    return sagemaker_client.list_artifacts(SourceUri=s3_uri)
399+
338400

339401
class ModelArtifact(Artifact):
340402
"""A SageMaker lineage artifact representing a model.
@@ -349,7 +411,7 @@ def endpoints(self) -> list:
349411
"""Get association summaries for endpoints deployed with this model.
350412
351413
Returns:
352-
[AssociationSummary]: A list of associations repesenting the endpoints using the model.
414+
[AssociationSummary]: A list of associations representing the endpoints using the model.
353415
"""
354416
endpoint_development_actions: Iterator = Association.list(
355417
source_arn=self.artifact_arn,
@@ -522,3 +584,75 @@ def endpoint_contexts(
522584
for vertex in query_result.vertices:
523585
endpoint_contexts.append(vertex.to_lineage_object())
524586
return endpoint_contexts
587+
588+
def get_upstream_datasets(
    self, direction: LineageQueryDirectionEnum = LineageQueryDirectionEnum.ASCENDANTS
) -> List[Artifact]:
    """Get upstream artifacts representing datasets from the dataset's lineage.

    Args:
        direction (LineageQueryDirectionEnum, optional): The query direction.
            Defaults to ``LineageQueryDirectionEnum.ASCENDANTS``.

    Returns:
        list of Artifacts: Artifacts representing a dataset.
    """
    # Keep only artifact vertices whose source is a dataset.
    dataset_filter = LineageFilter(
        entities=[LineageEntityEnum.ARTIFACT], sources=[LineageSourceEnum.DATASET]
    )
    lineage = LineageQuery(self.sagemaker_session).query(
        start_arns=[self.artifact_arn],
        query_filter=dataset_filter,
        direction=direction,
        include_edges=False,
    )
    datasets = []
    for vertex in lineage.vertices:
        datasets.append(vertex.to_lineage_object())
    return datasets
609+
610+
def get_downstream_datasets(
    self, direction: LineageQueryDirectionEnum = LineageQueryDirectionEnum.DESCENDANTS
) -> List[Artifact]:
    """Get downstream artifacts representing datasets from the dataset's lineage.

    Args:
        direction (LineageQueryDirectionEnum, optional): The query direction.
            Defaults to ``LineageQueryDirectionEnum.DESCENDANTS``.

    Returns:
        list of Artifacts: Artifacts representing a dataset.
    """
    # Keep only artifact vertices whose source is a dataset.
    dataset_filter = LineageFilter(
        entities=[LineageEntityEnum.ARTIFACT], sources=[LineageSourceEnum.DATASET]
    )
    lineage = LineageQuery(self.sagemaker_session).query(
        start_arns=[self.artifact_arn],
        query_filter=dataset_filter,
        direction=direction,
        include_edges=False,
    )
    datasets = []
    for vertex in lineage.vertices:
        datasets.append(vertex.to_lineage_object())
    return datasets
631+
632+
633+
class ImageArtifact(Artifact):
    """A SageMaker lineage artifact representing an image.

    Common image specific lineage traversals to discover how the image is connected
    to other entities.
    """

    def get_dataset(self, direction: LineageQueryDirectionEnum) -> List[Artifact]:
        """Get artifacts representing datasets from the image artifact's lineage.

        Args:
            direction (LineageQueryDirectionEnum): The query direction.

        Returns:
            list of Artifacts: Artifacts representing a dataset.
        """
        # Keep only artifact vertices whose source is a dataset.
        dataset_filter = LineageFilter(
            entities=[LineageEntityEnum.ARTIFACT], sources=[LineageSourceEnum.DATASET]
        )
        lineage = LineageQuery(self.sagemaker_session).query(
            start_arns=[self.artifact_arn],
            query_filter=dataset_filter,
            direction=direction,
            include_edges=False,
        )
        datasets = []
        for vertex in lineage.vertices:
            datasets.append(vertex.to_lineage_object())
        return datasets

src/sagemaker/lineage/context.py

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@
3131
LineageQueryDirectionEnum,
3232
)
3333
from sagemaker.lineage.artifact import Artifact
34+
from sagemaker.lineage.action import Action
35+
from sagemaker.lineage.lineage_trial_component import LineageTrialComponent
3436

3537

3638
class Context(_base_types.Record):
@@ -256,6 +258,24 @@ def list(
256258
sagemaker_session=sagemaker_session,
257259
)
258260

261+
def actions(self, direction: LineageQueryDirectionEnum) -> List[Action]:
    """Get actions from this context's lineage.

    Args:
        direction (LineageQueryDirectionEnum): The query direction.

    Returns:
        list of Actions: Actions.
    """
    # Restrict the query to action vertices; edges are not requested.
    action_filter = LineageFilter(entities=[LineageEntityEnum.ACTION])
    lineage = LineageQuery(self.sagemaker_session).query(
        start_arns=[self.context_arn],
        query_filter=action_filter,
        direction=direction,
        include_edges=False,
    )
    found_actions = []
    for vertex in lineage.vertices:
        found_actions.append(vertex.to_lineage_object())
    return found_actions
278+
259279

260280
class EndpointContext(Context):
261281
"""An Amazon SageMaker endpoint context, which is part of a SageMaker lineage."""
@@ -360,6 +380,9 @@ def training_job_arns(
360380
) -> List[str]:
361381
"""Get ARNs for all training jobs that appear in the endpoint's lineage.
362382
383+
Args:
384+
direction (LineageQueryDirectionEnum, optional): The query direction.
385+
363386
Returns:
364387
list of str: Training job ARNs.
365388
"""
@@ -382,11 +405,86 @@ def training_job_arns(
382405
training_job_arns.append(trial_component["Source"]["SourceArn"])
383406
return training_job_arns
384407

408+
def processing_job_arns(
    self, direction: LineageQueryDirectionEnum = LineageQueryDirectionEnum.ASCENDANTS
) -> List[str]:
    """Get ARNs for all processing jobs that appear in the endpoint's lineage.

    Args:
        direction (LineageQueryDirectionEnum, optional): The query direction.
            Defaults to ``LineageQueryDirectionEnum.ASCENDANTS``.

    Returns:
        list of str: Processing job ARNs.
    """
    # Keep only trial component vertices whose source is a processing job.
    processing_filter = LineageFilter(
        entities=[LineageEntityEnum.TRIAL_COMPONENT], sources=[LineageSourceEnum.PROCESSING_JOB]
    )
    lineage = LineageQuery(self.sagemaker_session).query(
        start_arns=[self.context_arn],
        query_filter=processing_filter,
        direction=direction,
        include_edges=False,
    )

    # One describe_trial_component call is issued per vertex; the job ARN is read
    # from the "Source" -> "SourceArn" entry of each response.
    return [
        self.sagemaker_session.sagemaker_client.describe_trial_component(
            TrialComponentName=_utils.get_resource_name_from_arn(vertex.arn)
        )["Source"]["SourceArn"]
        for vertex in lineage.vertices
    ]
437+
438+
def transform_jobs(
    self, direction: LineageQueryDirectionEnum = LineageQueryDirectionEnum.ASCENDANTS
) -> List[LineageTrialComponent]:
    """Get the transform jobs that appear in the endpoint's lineage.

    Args:
        direction (LineageQueryDirectionEnum, optional): The query direction.
            Defaults to ``LineageQueryDirectionEnum.ASCENDANTS``.

    Returns:
        list of LineageTrialComponent: The lineage object of each matching vertex
        (the vertices themselves are converted, not their ARNs).
    """
    # Keep only trial component vertices whose source is a transform job.
    transform_filter = LineageFilter(
        entities=[LineageEntityEnum.TRIAL_COMPONENT], sources=[LineageSourceEnum.TRANSFORM_JOB]
    )
    lineage = LineageQuery(self.sagemaker_session).query(
        start_arns=[self.context_arn],
        query_filter=transform_filter,
        direction=direction,
        include_edges=False,
    )
    jobs = []
    for vertex in lineage.vertices:
        jobs.append(vertex.to_lineage_object())
    return jobs
459+
460+
def trial_components(
    self, direction: LineageQueryDirectionEnum = LineageQueryDirectionEnum.ASCENDANTS
) -> List[LineageTrialComponent]:
    """Get all trial components that appear in the endpoint's lineage.

    Args:
        direction (LineageQueryDirectionEnum, optional): The query direction.
            Defaults to ``LineageQueryDirectionEnum.ASCENDANTS``.

    Returns:
        list of LineageTrialComponent: The lineage object of each matching vertex
        (the vertices themselves are converted, not their ARNs).
    """
    # Restrict the query to trial component vertices; edges are not requested.
    component_filter = LineageFilter(entities=[LineageEntityEnum.TRIAL_COMPONENT])
    lineage = LineageQuery(self.sagemaker_session).query(
        start_arns=[self.context_arn],
        query_filter=component_filter,
        direction=direction,
        include_edges=False,
    )
    components = []
    for vertex in lineage.vertices:
        components.append(vertex.to_lineage_object())
    return components
479+
385480
def pipeline_execution_arn(
386481
self, direction: LineageQueryDirectionEnum = LineageQueryDirectionEnum.ASCENDANTS
387482
) -> str:
388483
"""Get the ARN for the pipeline execution associated with this endpoint (if any).
389484
485+
Args:
486+
direction (LineageQueryDirectionEnum, optional): The query direction.
487+
390488
Returns:
391489
str: A pipeline execution ARN.
392490
"""

0 commit comments

Comments
 (0)