Skip to content

Commit a9471df

Browse files
committed
feature: Add support for SageMaker lineage queries in artifact, context and trial component
1 parent 948debc commit a9471df

File tree

14 files changed

+812
-17
lines changed

14 files changed

+812
-17
lines changed

src/sagemaker/lineage/action.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from typing import Optional, Iterator
1717
from datetime import datetime
1818

19-
from sagemaker import Session
19+
from sagemaker.session import Session
2020
from sagemaker.apiutils import _base_types
2121
from sagemaker.lineage import _api_types, _utils
2222
from sagemaker.lineage._api_types import ActionSource, ActionSummary
@@ -116,7 +116,7 @@ def delete(self, disassociate: bool = False):
116116
self._invoke_api(self._boto_delete_method, self._boto_delete_members)
117117

118118
@classmethod
119-
def load(cls, action_name: str, sagemaker_session: Session = None) -> "Action":
119+
def load(cls, action_name: str, sagemaker_session=None) -> "Action":
120120
"""Load an existing action and return an ``Action`` object representing it.
121121
122122
Args:

src/sagemaker/lineage/artifact.py

Lines changed: 137 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,7 @@ def downstream_trials(self, sagemaker_session=None) -> list:
146146
"""Retrieve all trial runs which that use this artifact.
147147
148148
Args:
149-
sagemaker_session (obj): Sagemaker Sesssion to use. If not provided a default session
149+
sagemaker_session (obj): Sagemaker Session to use. If not provided a default session
150150
will be created.
151151
152152
Returns:
@@ -159,6 +159,57 @@ def downstream_trials(self, sagemaker_session=None) -> list:
159159
)
160160
trial_component_arns: list = list(map(lambda x: x.destination_arn, outgoing_associations))
161161

162+
return self._get_trial_from_trial_component(trial_component_arns)
163+
164+
def downstream_trials_v2(
165+
self, direction: LineageQueryDirectionEnum = LineageQueryDirectionEnum.DESCENDANTS
166+
) -> list:
167+
"""Retrieve all downstream trial runs which that use this artifact by using lineage query.
168+
169+
Args:
170+
direction (LineageQueryDirectionEnum, optional): The query direction.
171+
172+
Returns:
173+
[Trial]: A list of SageMaker `Trial` objects.
174+
"""
175+
query_filter = LineageFilter(entities=[LineageEntityEnum.TRIAL_COMPONENT])
176+
query_result = LineageQuery(self.sagemaker_session).query(
177+
start_arns=[self.artifact_arn],
178+
query_filter=query_filter,
179+
direction=direction,
180+
include_edges=False,
181+
)
182+
trial_component_arns: list = list(map(lambda x: x.arn, query_result.vertices))
183+
184+
return self._get_trial_from_trial_component(trial_component_arns)
185+
186+
def get_upstream_trials(
187+
self, direction: LineageQueryDirectionEnum = LineageQueryDirectionEnum.ASCENDANTS
188+
) -> List[str]:
189+
"""Retrieve all upstream trial runs which that use this artifact by using lineage query.
190+
191+
Args:
192+
direction (LineageQueryDirectionEnum, optional): The query direction.
193+
194+
Returns:
195+
[Trial]: A list of SageMaker `Trial` objects.
196+
"""
197+
query_filter = LineageFilter(entities=[LineageEntityEnum.TRIAL_COMPONENT])
198+
query_result = LineageQuery(self.sagemaker_session).query(
199+
start_arns=[self.artifact_arn],
200+
query_filter=query_filter,
201+
direction=direction,
202+
include_edges=False,
203+
)
204+
trial_component_arns: list = list(map(lambda x: x.arn, query_result.vertices))
205+
return self._get_trial_from_trial_component(trial_component_arns)
206+
207+
def _get_trial_from_trial_component(self, trial_component_arns: list):
208+
"""Retrieve all upstream trial runs which that use the trial component arns.
209+
210+
Returns:
211+
[Trial]: A list of SageMaker `Trial` objects.
212+
"""
162213
if not trial_component_arns:
163214
# no outgoing associations for this artifact
164215
return []
@@ -170,7 +221,7 @@ def downstream_trials(self, sagemaker_session=None) -> list:
170221
num_search_batches = math.ceil(len(trial_component_arns) % max_search_by_arn)
171222
trial_components: list = []
172223

173-
sagemaker_session = sagemaker_session or _utils.default_session()
224+
sagemaker_session = self.sagemaker_session or _utils.default_session()
174225
sagemaker_client = sagemaker_session.sagemaker_client
175226

176227
for i in range(num_search_batches):
@@ -335,6 +386,17 @@ def list(
335386
sagemaker_session=sagemaker_session,
336387
)
337388

389+
def get_artifacts(self, s3_uri: str):
390+
"""Return a list of artifact that use provided s3 uri.
391+
392+
Args:
393+
s3_uri (str): A S3 URI.
394+
395+
Returns:
396+
A list of ``Artifacts``
397+
"""
398+
return self.sagemaker_session.sagemaker_client.list_artifacts(SourceUri=s3_uri)
399+
338400

339401
class ModelArtifact(Artifact):
340402
"""A SageMaker lineage artifact representing a model.
@@ -349,7 +411,7 @@ def endpoints(self) -> list:
349411
"""Get association summaries for endpoints deployed with this model.
350412
351413
Returns:
352-
[AssociationSummary]: A list of associations repesenting the endpoints using the model.
414+
[AssociationSummary]: A list of associations representing the endpoints using the model.
353415
"""
354416
endpoint_development_actions: Iterator = Association.list(
355417
source_arn=self.artifact_arn,
@@ -522,3 +584,75 @@ def endpoint_contexts(
522584
for vertex in query_result.vertices:
523585
endpoint_contexts.append(vertex.to_lineage_object())
524586
return endpoint_contexts
587+
588+
def get_upstream_datasets(
589+
self, direction: LineageQueryDirectionEnum = LineageQueryDirectionEnum.ASCENDANTS
590+
) -> List[Artifact]:
591+
"""Get upstream artifacts representing dataset from the dataset's lineage.
592+
593+
Args:
594+
direction (LineageQueryDirectionEnum, optional): The query direction.
595+
596+
Returns:
597+
list of Artifacts: Artifacts representing an dataset.
598+
"""
599+
query_filter = LineageFilter(
600+
entities=[LineageEntityEnum.ARTIFACT], sources=[LineageSourceEnum.DATASET]
601+
)
602+
query_result = LineageQuery(self.sagemaker_session).query(
603+
start_arns=[self.artifact_arn],
604+
query_filter=query_filter,
605+
direction=direction,
606+
include_edges=False,
607+
)
608+
return [vertex.to_lineage_object() for vertex in query_result.vertices]
609+
610+
def get_downstream_datasets(
611+
self, direction: LineageQueryDirectionEnum = LineageQueryDirectionEnum.DESCENDANTS
612+
) -> List[Artifact]:
613+
"""Get downstream artifacts representing dataset from the dataset's lineage.
614+
615+
Args:
616+
direction (LineageQueryDirectionEnum, optional): The query direction.
617+
618+
Returns:
619+
list of Artifacts: Artifacts representing an dataset.
620+
"""
621+
query_filter = LineageFilter(
622+
entities=[LineageEntityEnum.ARTIFACT], sources=[LineageSourceEnum.DATASET]
623+
)
624+
query_result = LineageQuery(self.sagemaker_session).query(
625+
start_arns=[self.artifact_arn],
626+
query_filter=query_filter,
627+
direction=direction,
628+
include_edges=False,
629+
)
630+
return [vertex.to_lineage_object() for vertex in query_result.vertices]
631+
632+
633+
class ImageArtifact(Artifact):
634+
"""A SageMaker lineage artifact representing an image.
635+
636+
Common model specific lineage traversals to discover how the image is connected
637+
to other entities.
638+
"""
639+
640+
def get_dataset(self, direction: LineageQueryDirectionEnum) -> List[Artifact]:
641+
"""Get artifacts representing dataset from the image artifact's lineage.
642+
643+
Args:
644+
direction (LineageQueryDirectionEnum): The query direction.
645+
646+
Returns:
647+
list of Artifacts: Artifacts representing an dataset.
648+
"""
649+
query_filter = LineageFilter(
650+
entities=[LineageEntityEnum.ARTIFACT], sources=[LineageSourceEnum.DATASET]
651+
)
652+
query_result = LineageQuery(self.sagemaker_session).query(
653+
start_arns=[self.artifact_arn],
654+
query_filter=query_filter,
655+
direction=direction,
656+
include_edges=False,
657+
)
658+
return [vertex.to_lineage_object() for vertex in query_result.vertices]

src/sagemaker/lineage/context.py

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@
3131
LineageQueryDirectionEnum,
3232
)
3333
from sagemaker.lineage.artifact import Artifact
34+
from sagemaker.lineage.action import Action
35+
from sagemaker.lineage.lineage_trial_component import LineageTrialComponent
3436

3537

3638
class Context(_base_types.Record):
@@ -256,6 +258,24 @@ def list(
256258
sagemaker_session=sagemaker_session,
257259
)
258260

261+
def actions(self, direction: LineageQueryDirectionEnum) -> List[Action]:
262+
"""Get actions from the endpoint's lineage.
263+
264+
Args:
265+
direction (LineageQueryDirectionEnum): The query direction.
266+
267+
Returns:
268+
list of Actions: Actions.
269+
"""
270+
query_filter = LineageFilter(entities=[LineageEntityEnum.ACTION])
271+
query_result = LineageQuery(self.sagemaker_session).query(
272+
start_arns=[self.context_arn],
273+
query_filter=query_filter,
274+
direction=direction,
275+
include_edges=False,
276+
)
277+
return [vertex.to_lineage_object() for vertex in query_result.vertices]
278+
259279

260280
class EndpointContext(Context):
261281
"""An Amazon SageMaker endpoint context, which is part of a SageMaker lineage."""
@@ -311,6 +331,9 @@ def training_job_arns(
311331
) -> List[str]:
312332
"""Get ARNs for all training jobs that appear in the endpoint's lineage.
313333
334+
Args:
335+
direction (LineageQueryDirectionEnum, optional): The query direction.
336+
314337
Returns:
315338
list of str: Training job ARNs.
316339
"""
@@ -333,11 +356,86 @@ def training_job_arns(
333356
training_job_arns.append(trial_component["Source"]["SourceArn"])
334357
return training_job_arns
335358

359+
def processing_job_arns(
360+
self, direction: LineageQueryDirectionEnum = LineageQueryDirectionEnum.ASCENDANTS
361+
) -> List[str]:
362+
"""Get ARNs for all processing jobs that appear in the endpoint's lineage.
363+
364+
Args:
365+
direction (LineageQueryDirectionEnum, optional): The query direction.
366+
367+
Returns:
368+
list of str: Processing job ARNs.
369+
"""
370+
query_filter = LineageFilter(
371+
entities=[LineageEntityEnum.TRIAL_COMPONENT], sources=[LineageSourceEnum.PROCESSING_JOB]
372+
)
373+
query_result = LineageQuery(self.sagemaker_session).query(
374+
start_arns=[self.context_arn],
375+
query_filter=query_filter,
376+
direction=direction,
377+
include_edges=False,
378+
)
379+
380+
processing_job_arns = []
381+
for vertex in query_result.vertices:
382+
trial_component_name = _utils.get_resource_name_from_arn(vertex.arn)
383+
trial_component = self.sagemaker_session.sagemaker_client.describe_trial_component(
384+
TrialComponentName=trial_component_name
385+
)
386+
processing_job_arns.append(trial_component["Source"]["SourceArn"])
387+
return processing_job_arns
388+
389+
def transform_jobs(
390+
self, direction: LineageQueryDirectionEnum = LineageQueryDirectionEnum.ASCENDANTS
391+
) -> List[LineageTrialComponent]:
392+
"""Get ARNs for all transform jobs that appear in the endpoint's lineage.
393+
394+
Args:
395+
direction (LineageQueryDirectionEnum, optional): The query direction.
396+
397+
Returns:
398+
list of str: Transform jobs.
399+
"""
400+
query_filter = LineageFilter(
401+
entities=[LineageEntityEnum.TRIAL_COMPONENT], sources=[LineageSourceEnum.TRANSFORM_JOB]
402+
)
403+
query_result = LineageQuery(self.sagemaker_session).query(
404+
start_arns=[self.context_arn],
405+
query_filter=query_filter,
406+
direction=direction,
407+
include_edges=False,
408+
)
409+
return [vertex.to_lineage_object() for vertex in query_result.vertices]
410+
411+
def trial_components(
412+
self, direction: LineageQueryDirectionEnum = LineageQueryDirectionEnum.ASCENDANTS
413+
) -> List[LineageTrialComponent]:
414+
"""Get ARNs for all trial components that appear in the endpoint's lineage.
415+
416+
Args:
417+
direction (LineageQueryDirectionEnum, optional): The query direction.
418+
419+
Returns:
420+
list of str: Trial components ARNs.
421+
"""
422+
query_filter = LineageFilter(entities=[LineageEntityEnum.TRIAL_COMPONENT])
423+
query_result = LineageQuery(self.sagemaker_session).query(
424+
start_arns=[self.context_arn],
425+
query_filter=query_filter,
426+
direction=direction,
427+
include_edges=False,
428+
)
429+
return [vertex.to_lineage_object() for vertex in query_result.vertices]
430+
336431
def pipeline_execution_arn(
337432
self, direction: LineageQueryDirectionEnum = LineageQueryDirectionEnum.ASCENDANTS
338433
) -> str:
339434
"""Get the ARN for the pipeline execution associated with this endpoint (if any).
340435
436+
Args:
437+
direction (LineageQueryDirectionEnum, optional): The query direction.
438+
341439
Returns:
342440
str: A pipeline execution ARN.
343441
"""

0 commit comments

Comments
 (0)