Skip to content

Commit ddda863

Browse files
committed
feature: Add support for SageMaker lineage queries in artifact, context and trial component
1 parent 975e031 commit ddda863

File tree

14 files changed

+787
-21
lines changed

14 files changed

+787
-21
lines changed

src/sagemaker/lineage/action.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from typing import Optional, Iterator
1717
from datetime import datetime
1818

19-
from sagemaker import Session
19+
from sagemaker.session import Session
2020
from sagemaker.apiutils import _base_types
2121
from sagemaker.lineage import _api_types, _utils
2222
from sagemaker.lineage._api_types import ActionSource, ActionSummary
@@ -116,7 +116,7 @@ def delete(self, disassociate: bool = False):
116116
self._invoke_api(self._boto_delete_method, self._boto_delete_members)
117117

118118
@classmethod
119-
def load(cls, action_name: str, sagemaker_session: Session = None) -> "Action":
119+
def load(cls, action_name: str, sagemaker_session=None) -> "Action":
120120
"""Load an existing action and return an ``Action`` object representing it.
121121
122122
Args:

src/sagemaker/lineage/artifact.py

Lines changed: 126 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -143,10 +143,10 @@ def load(cls, artifact_arn: str, sagemaker_session=None) -> "Artifact":
143143
return artifact
144144

145145
def downstream_trials(self, sagemaker_session=None) -> list:
146-
"""Retrieve all trial runs which that use this artifact.
146+
"""Use the lineage API to retrieve all downstream trials that use this artifact.
147147
148148
Args:
149-
sagemaker_session (obj): Sagemaker Sesssion to use. If not provided a default session
149+
sagemaker_session (obj): Sagemaker Session to use. If not provided a default session
150150
will be created.
151151
152152
Returns:
@@ -159,6 +159,51 @@ def downstream_trials(self, sagemaker_session=None) -> list:
159159
)
160160
trial_component_arns: list = list(map(lambda x: x.destination_arn, outgoing_associations))
161161

162+
return self._get_trial_from_trial_component(trial_component_arns)
163+
164+
def downstream_trials_v2(self) -> list:
    """Look up the trials downstream of this artifact with a lineage query.

    Thin wrapper that runs ``self._trials`` in the ``DESCENDANTS`` direction.

    Returns:
        [Trial]: A list of SageMaker `Trial` objects.
    """
    descendants = LineageQueryDirectionEnum.DESCENDANTS
    return self._trials(direction=descendants)
171+
172+
def upstream_trials(self) -> List:
    """Look up the trials upstream of this artifact with a lineage query.

    Thin wrapper that runs ``self._trials`` in the ``ASCENDANTS`` direction.

    Returns:
        [Trial]: A list of SageMaker `Trial` objects.
    """
    ascendants = LineageQueryDirectionEnum.ASCENDANTS
    return self._trials(direction=ascendants)
179+
180+
def _trials(
    self, direction: LineageQueryDirectionEnum = LineageQueryDirectionEnum.BOTH
) -> List:
    """Query the lineage graph for trials connected to this artifact.

    Starts a lineage query at ``self.artifact_arn`` that is restricted to trial
    component vertices (edges are not requested), then passes the ARNs of the
    returned vertices to ``self._get_trial_from_trial_component``.

    Args:
        direction (LineageQueryDirectionEnum, optional): The query direction.
            Defaults to ``LineageQueryDirectionEnum.BOTH``.

    Returns:
        [Trial]: A list of SageMaker `Trial` objects.
    """
    tc_filter = LineageFilter(entities=[LineageEntityEnum.TRIAL_COMPONENT])
    lineage_query = LineageQuery(self.sagemaker_session)
    response = lineage_query.query(
        start_arns=[self.artifact_arn],
        query_filter=tc_filter,
        direction=direction,
        include_edges=False,
    )
    # Only the vertex ARNs are needed to resolve the owning trials.
    component_arns: list = [vertex.arn for vertex in response.vertices]
    return self._get_trial_from_trial_component(component_arns)
200+
201+
def _get_trial_from_trial_component(self, trial_component_arns: list) -> List:
202+
"""Retrieve all upstream trial runs which that use the trial component arns.
203+
204+
Returns:
205+
[Trial]: A list of SageMaker `Trial` objects.
206+
"""
162207
if not trial_component_arns:
163208
# no outgoing associations for this artifact
164209
return []
@@ -170,7 +215,7 @@ def downstream_trials(self, sagemaker_session=None) -> list:
170215
num_search_batches = math.ceil(len(trial_component_arns) % max_search_by_arn)
171216
trial_components: list = []
172217

173-
sagemaker_session = sagemaker_session or _utils.default_session()
218+
sagemaker_session = self.sagemaker_session or _utils.default_session()
174219
sagemaker_client = sagemaker_session.sagemaker_client
175220

176221
for i in range(num_search_batches):
@@ -335,6 +380,17 @@ def list(
335380
sagemaker_session=sagemaker_session,
336381
)
337382

383+
def s3_uri_artifacts(self, s3_uri: str) -> dict:
    """Ask the SageMaker client for the artifacts whose source is the given S3 URI.

    The call is forwarded as ``list_artifacts(SourceUri=s3_uri)`` on
    ``self.sagemaker_session.sagemaker_client`` and its response is returned
    unchanged.

    Args:
        s3_uri (str): A S3 URI.

    Returns:
        dict: The raw ``list_artifacts`` response, not ``Artifact`` objects.
            For a boto3 SageMaker client this is expected to hold the matches
            under ``ArtifactSummaries`` (plus ``NextToken`` when more pages
            exist) -- NOTE(review): confirm against the client in use. No
            pagination is performed here.
    """
    sagemaker_client = self.sagemaker_session.sagemaker_client
    return sagemaker_client.list_artifacts(SourceUri=s3_uri)
393+
338394

339395
class ModelArtifact(Artifact):
340396
"""A SageMaker lineage artifact representing a model.
@@ -349,7 +405,7 @@ def endpoints(self) -> list:
349405
"""Get association summaries for endpoints deployed with this model.
350406
351407
Returns:
352-
[AssociationSummary]: A list of associations repesenting the endpoints using the model.
408+
[AssociationSummary]: A list of associations representing the endpoints using the model.
353409
"""
354410
endpoint_development_actions: Iterator = Association.list(
355411
source_arn=self.artifact_arn,
@@ -522,3 +578,69 @@ def endpoint_contexts(
522578
for vertex in query_result.vertices:
523579
endpoint_contexts.append(vertex.to_lineage_object())
524580
return endpoint_contexts
581+
582+
def upstream_datasets(self) -> List[Artifact]:
    """Use a lineage query to find dataset artifacts upstream of this artifact.

    Thin wrapper that runs ``self._datasets`` in the ``ASCENDANTS`` direction.

    Returns:
        list of Artifacts: Artifacts representing a dataset.
    """
    ascendants = LineageQueryDirectionEnum.ASCENDANTS
    return self._datasets(direction=ascendants)
589+
590+
def downstream_datasets(self) -> List[Artifact]:
    """Use a lineage query to find dataset artifacts downstream of this artifact.

    Thin wrapper that runs ``self._datasets`` in the ``DESCENDANTS`` direction.

    Returns:
        list of Artifacts: Artifacts representing a dataset.
    """
    descendants = LineageQueryDirectionEnum.DESCENDANTS
    return self._datasets(direction=descendants)
597+
598+
def _datasets(
    self, direction: LineageQueryDirectionEnum = LineageQueryDirectionEnum.BOTH
) -> List[Artifact]:
    """Query the lineage graph for dataset artifacts connected to this artifact.

    The query starts at ``self.artifact_arn`` and is filtered to artifact
    entities whose source is ``DATASET``; edges are not requested.

    Args:
        direction (LineageQueryDirectionEnum, optional): The query direction.
            Defaults to ``LineageQueryDirectionEnum.BOTH``.

    Returns:
        list of Artifacts: One object per returned vertex, produced by the
            vertex's ``to_lineage_object()``.
    """
    dataset_filter = LineageFilter(
        entities=[LineageEntityEnum.ARTIFACT],
        sources=[LineageSourceEnum.DATASET],
    )
    lineage_query = LineageQuery(self.sagemaker_session)
    response = lineage_query.query(
        start_arns=[self.artifact_arn],
        query_filter=dataset_filter,
        direction=direction,
        include_edges=False,
    )
    datasets = []
    for vertex in response.vertices:
        datasets.append(vertex.to_lineage_object())
    return datasets
619+
620+
621+
class ImageArtifact(Artifact):
    """A SageMaker lineage artifact representing an image.

    Provides lineage traversals for discovering how the image is connected
    to other entities.
    """

    def datasets(self, direction: LineageQueryDirectionEnum) -> List[Artifact]:
        """Use a lineage query to find dataset artifacts related to this image.

        The query starts at ``self.artifact_arn`` and is filtered to artifact
        entities whose source is ``DATASET``; edges are not requested.

        Args:
            direction (LineageQueryDirectionEnum): The query direction.

        Returns:
            list of Artifacts: Artifacts representing a dataset.
        """
        dataset_filter = LineageFilter(
            entities=[LineageEntityEnum.ARTIFACT],
            sources=[LineageSourceEnum.DATASET],
        )
        lineage_query = LineageQuery(self.sagemaker_session)
        response = lineage_query.query(
            start_arns=[self.artifact_arn],
            query_filter=dataset_filter,
            direction=direction,
            include_edges=False,
        )
        found = []
        for vertex in response.vertices:
            found.append(vertex.to_lineage_object())
        return found

src/sagemaker/lineage/context.py

Lines changed: 93 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@
3131
LineageQueryDirectionEnum,
3232
)
3333
from sagemaker.lineage.artifact import Artifact
34+
from sagemaker.lineage.action import Action
35+
from sagemaker.lineage.lineage_trial_component import LineageTrialComponent
3436

3537

3638
class Context(_base_types.Record):
@@ -256,12 +258,30 @@ def list(
256258
sagemaker_session=sagemaker_session,
257259
)
258260

261+
def actions(self, direction: LineageQueryDirectionEnum) -> List[Action]:
    """Run a lineage query from this context and collect the action vertices.

    The query starts at ``self.context_arn`` and is filtered to ``ACTION``
    entities; edges are not requested.

    Args:
        direction (LineageQueryDirectionEnum): The query direction.

    Returns:
        list of Actions: One object per returned vertex, produced by the
            vertex's ``to_lineage_object()``.
    """
    action_filter = LineageFilter(entities=[LineageEntityEnum.ACTION])
    lineage_query = LineageQuery(self.sagemaker_session)
    response = lineage_query.query(
        start_arns=[self.context_arn],
        query_filter=action_filter,
        direction=direction,
        include_edges=False,
    )
    found = []
    for vertex in response.vertices:
        found.append(vertex.to_lineage_object())
    return found
278+
259279

260280
class EndpointContext(Context):
261281
"""An Amazon SageMaker endpoint context, which is part of a SageMaker lineage."""
262282

263283
def models(self) -> List[association.Association]:
264-
"""Get all models deployed by all endpoint versions of the endpoint.
284+
"""Use Lineage API to get all models deployed by this endpoint.
265285
266286
Returns:
267287
list of Associations: Associations that destination represents an endpoint's model.
@@ -286,7 +306,7 @@ def models(self) -> List[association.Association]:
286306
def models_v2(
287307
self, direction: LineageQueryDirectionEnum = LineageQueryDirectionEnum.DESCENDANTS
288308
) -> List[Artifact]:
289-
"""Get artifacts representing models from the context lineage by querying lineage data.
309+
"""Use the lineage query to retrieve downstream model artifacts that use this endpoint.
290310
291311
Args:
292312
direction (LineageQueryDirectionEnum, optional): The query direction.
@@ -335,7 +355,7 @@ def models_v2(
335355
def dataset_artifacts(
336356
self, direction: LineageQueryDirectionEnum = LineageQueryDirectionEnum.ASCENDANTS
337357
) -> List[Artifact]:
338-
"""Get artifacts representing datasets from the endpoint's lineage.
358+
"""Use the lineage query to retrieve datasets that use this endpoint.
339359
340360
Args:
341361
direction (LineageQueryDirectionEnum, optional): The query direction.
@@ -360,6 +380,9 @@ def training_job_arns(
360380
) -> List[str]:
361381
"""Get ARNs for all training jobs that appear in the endpoint's lineage.
362382
383+
Args:
384+
direction (LineageQueryDirectionEnum, optional): The query direction.
385+
363386
Returns:
364387
list of str: Training job ARNs.
365388
"""
@@ -382,11 +405,78 @@ def training_job_arns(
382405
training_job_arns.append(trial_component["Source"]["SourceArn"])
383406
return training_job_arns
384407

408+
def processing_jobs(
    self, direction: LineageQueryDirectionEnum = LineageQueryDirectionEnum.ASCENDANTS
) -> List[LineageTrialComponent]:
    """Use a lineage query to find processing jobs in this context's lineage.

    The query starts at ``self.context_arn`` and is filtered to trial
    component entities whose source is ``PROCESSING_JOB``; edges are not
    requested.

    Args:
        direction (LineageQueryDirectionEnum, optional): The query direction.
            Defaults to ``LineageQueryDirectionEnum.ASCENDANTS``.

    Returns:
        list of LineageTrialComponent: Lineage trial components that represent
            Processing jobs.
    """
    job_filter = LineageFilter(
        entities=[LineageEntityEnum.TRIAL_COMPONENT],
        sources=[LineageSourceEnum.PROCESSING_JOB],
    )
    lineage_query = LineageQuery(self.sagemaker_session)
    response = lineage_query.query(
        start_arns=[self.context_arn],
        query_filter=job_filter,
        direction=direction,
        include_edges=False,
    )
    jobs = []
    for vertex in response.vertices:
        jobs.append(vertex.to_lineage_object())
    return jobs
429+
430+
def transform_jobs(
    self, direction: LineageQueryDirectionEnum = LineageQueryDirectionEnum.ASCENDANTS
) -> List[LineageTrialComponent]:
    """Use a lineage query to find transform jobs in this context's lineage.

    The query starts at ``self.context_arn`` and is filtered to trial
    component entities whose source is ``TRANSFORM_JOB``; edges are not
    requested.

    Args:
        direction (LineageQueryDirectionEnum, optional): The query direction.
            Defaults to ``LineageQueryDirectionEnum.ASCENDANTS``.

    Returns:
        list of LineageTrialComponent: Lineage trial components that represent
            Transform jobs.
    """
    job_filter = LineageFilter(
        entities=[LineageEntityEnum.TRIAL_COMPONENT],
        sources=[LineageSourceEnum.TRANSFORM_JOB],
    )
    lineage_query = LineageQuery(self.sagemaker_session)
    response = lineage_query.query(
        start_arns=[self.context_arn],
        query_filter=job_filter,
        direction=direction,
        include_edges=False,
    )
    jobs = []
    for vertex in response.vertices:
        jobs.append(vertex.to_lineage_object())
    return jobs
451+
452+
def trial_components(
    self, direction: LineageQueryDirectionEnum = LineageQueryDirectionEnum.ASCENDANTS
) -> List[LineageTrialComponent]:
    """Use a lineage query to find trial components in this context's lineage.

    The query starts at ``self.context_arn`` and is filtered to trial
    component entities of any source; edges are not requested.

    Args:
        direction (LineageQueryDirectionEnum, optional): The query direction.
            Defaults to ``LineageQueryDirectionEnum.ASCENDANTS``.

    Returns:
        list of LineageTrialComponent: One object per returned vertex, produced
            by the vertex's ``to_lineage_object()``.
    """
    tc_filter = LineageFilter(entities=[LineageEntityEnum.TRIAL_COMPONENT])
    lineage_query = LineageQuery(self.sagemaker_session)
    response = lineage_query.query(
        start_arns=[self.context_arn],
        query_filter=tc_filter,
        direction=direction,
        include_edges=False,
    )
    components = []
    for vertex in response.vertices:
        components.append(vertex.to_lineage_object())
    return components
471+
385472
def pipeline_execution_arn(
386473
self, direction: LineageQueryDirectionEnum = LineageQueryDirectionEnum.ASCENDANTS
387474
) -> str:
388475
"""Get the ARN for the pipeline execution associated with this endpoint (if any).
389476
477+
Args:
478+
direction (LineageQueryDirectionEnum, optional): The query direction.
479+
390480
Returns:
391481
str: A pipeline execution ARN.
392482
"""

0 commit comments

Comments
 (0)