Skip to content

Commit 46957fd

Browse files
navinsoni and Payton Staub
authored and committed
feature: Add support for SageMaker lineage queries
Co-authored-by: Payton Staub <[email protected]>
1 parent 601c94e commit 46957fd

File tree

16 files changed

+9868
-19
lines changed

16 files changed

+9868
-19
lines changed

src/sagemaker/lineage/_utils.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,3 +52,15 @@ def get_module(module_name):
5252
return import_module(module_name)
5353
except ImportError:
5454
raise Exception("Cannot import module {}, please try again.".format(module_name))
55+
56+
57+
def get_resource_name_from_arn(arn):
    """Return the resource name portion of an ARN string.

    The ARN is split on its first five colons; the sixth field is then split
    on its first "/" and everything after that slash is returned.

    Args:
        arn (str): An ARN whose sixth colon-separated field contains a "/".

    Returns:
        str: The text following the first "/" of the sixth field.
    """
    arn_fields = arn.split(":", 5)
    resource_field = arn_fields[5]
    # Only the first slash is treated as a separator; later ones stay in the name.
    return resource_field.split("/", 1)[1]

src/sagemaker/lineage/artifact.py

Lines changed: 140 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,19 @@
1717
import math
1818

1919
from datetime import datetime
20-
from typing import Iterator, Union, Any, Optional
20+
from typing import Iterator, Union, Any, Optional, List
2121

2222
from sagemaker.apiutils import _base_types, _utils
2323
from sagemaker.lineage import _api_types
2424
from sagemaker.lineage._api_types import ArtifactSource, ArtifactSummary
25-
from sagemaker.lineage._utils import get_module, _disassociate
25+
from sagemaker.lineage.query import (
26+
LineageQuery,
27+
LineageFilter,
28+
LineageSourceEnum,
29+
LineageEntityEnum,
30+
LineageQueryDirectionEnum,
31+
)
32+
from sagemaker.lineage._utils import get_module, _disassociate, get_resource_name_from_arn
2633
from sagemaker.lineage.association import Association
2734

2835
LOGGER = logging.getLogger("sagemaker")
@@ -333,11 +340,13 @@ class ModelArtifact(Artifact):
333340
"""A SageMaker lineage artifact representing a model.
334341
335342
Common model specific lineage traversals to discover how the model is connected
336-
to otherentities.
343+
to other entities.
337344
"""
338345

346+
from sagemaker.lineage.context import Context
347+
339348
def endpoints(self) -> list:
340-
"""Given a model artifact, get all associated endpoint context.
349+
"""Get association summaries for endpoints deployed with this model.
341350
342351
Returns:
343352
[AssociationSummary]: A list of associations repesenting the endpoints using the model.
@@ -359,6 +368,104 @@ def endpoints(self) -> list:
359368
]
360369
return endpoint_context_list
361370

371+
def endpoint_contexts(
    self, direction: LineageQueryDirectionEnum = LineageQueryDirectionEnum.DESCENDANTS
) -> List[Context]:
    """Get contexts representing endpoints from the model's lineage.

    Runs a lineage query starting at this artifact's ARN, filtered to context
    entities whose source is an endpoint, with edges excluded.

    Args:
        direction (LineageQueryDirectionEnum, optional): The query direction.
            Defaults to ``LineageQueryDirectionEnum.DESCENDANTS``.

    Returns:
        list of Contexts: One lineage object per vertex returned by the query.
    """
    endpoint_filter = LineageFilter(
        entities=[LineageEntityEnum.CONTEXT], sources=[LineageSourceEnum.ENDPOINT]
    )
    lineage_query = LineageQuery(self.sagemaker_session)
    result = lineage_query.query(
        start_arns=[self.artifact_arn],
        query_filter=endpoint_filter,
        direction=direction,
        include_edges=False,
    )
    return [vertex.to_lineage_object() for vertex in result.vertices]
396+
397+
def dataset_artifacts(
    self, direction: LineageQueryDirectionEnum = LineageQueryDirectionEnum.ASCENDANTS
) -> List[Artifact]:
    """Get artifacts representing datasets from the model's lineage.

    Runs a lineage query starting at this artifact's ARN, filtered to artifact
    entities whose source is a dataset, with edges excluded.

    Args:
        direction (LineageQueryDirectionEnum, optional): The query direction.
            Defaults to ``LineageQueryDirectionEnum.ASCENDANTS``.

    Returns:
        list of Artifacts: One lineage object per vertex returned by the query.
    """
    dataset_filter = LineageFilter(
        entities=[LineageEntityEnum.ARTIFACT], sources=[LineageSourceEnum.DATASET]
    )
    lineage_query = LineageQuery(self.sagemaker_session)
    result = lineage_query.query(
        start_arns=[self.artifact_arn],
        query_filter=dataset_filter,
        direction=direction,
        include_edges=False,
    )
    return [vertex.to_lineage_object() for vertex in result.vertices]
422+
423+
def training_job_arns(
    self, direction: LineageQueryDirectionEnum = LineageQueryDirectionEnum.ASCENDANTS
) -> List[str]:
    """Get ARNs for all training jobs that appear in the model's lineage.

    Queries the lineage for trial-component vertices sourced from training
    jobs, describes each trial component by name, and collects the source ARN
    from each description.

    Args:
        direction (LineageQueryDirectionEnum, optional): The query direction.
            Defaults to ``LineageQueryDirectionEnum.ASCENDANTS``.

    Returns:
        list of str: Training job ARNs.
    """
    training_job_filter = LineageFilter(
        entities=[LineageEntityEnum.TRIAL_COMPONENT], sources=[LineageSourceEnum.TRAINING_JOB]
    )
    lineage_query = LineageQuery(self.sagemaker_session)
    result = lineage_query.query(
        start_arns=[self.artifact_arn],
        query_filter=training_job_filter,
        direction=direction,
        include_edges=False,
    )

    def _source_arn(vertex):
        # The trial component is looked up by the name embedded in its ARN.
        component_name = get_resource_name_from_arn(vertex.arn)
        description = self.sagemaker_session.sagemaker_client.describe_trial_component(
            TrialComponentName=component_name
        )
        return description["Source"]["SourceArn"]

    return [_source_arn(vertex) for vertex in result.vertices]
449+
450+
def pipeline_execution_arn(
    self, direction: LineageQueryDirectionEnum = LineageQueryDirectionEnum.ASCENDANTS
) -> Optional[str]:
    """Get the ARN for the pipeline execution associated with this model (if any).

    Lists the tags of each training job ARN returned by ``training_job_arns``
    and returns the value of the first tag whose key is
    ``sagemaker:pipeline-execution-arn``.

    Args:
        direction (LineageQueryDirectionEnum, optional): The query direction
            passed through to ``training_job_arns``. Defaults to
            ``LineageQueryDirectionEnum.ASCENDANTS``.

    Returns:
        str or None: A pipeline execution ARN, or None when no training job in
        the lineage carries the pipeline execution tag.
    """
    for training_job_arn in self.training_job_arns(direction=direction):
        tags = self.sagemaker_session.sagemaker_client.list_tags(ResourceArn=training_job_arn)[
            "Tags"
        ]
        for tag in tags:
            if tag["Key"] == "sagemaker:pipeline-execution-arn":
                # First match wins; remaining training jobs are not inspected.
                return tag["Value"]

    return None
468+
362469

363470
class DatasetArtifact(Artifact):
364471
"""A SageMaker Lineage artifact representing a dataset.
@@ -367,7 +474,9 @@ class DatasetArtifact(Artifact):
367474
connect to related entities.
368475
"""
369476

370-
def trained_models(self) -> list:
477+
from sagemaker.lineage.context import Context
478+
479+
def trained_models(self) -> List[Association]:
371480
"""Given a dataset artifact, get associated trained models.
372481
373482
Returns:
@@ -387,3 +496,29 @@ def trained_models(self) -> list:
387496
result.extend(models)
388497

389498
return result
499+
500+
def endpoint_contexts(
    self, direction: LineageQueryDirectionEnum = LineageQueryDirectionEnum.DESCENDANTS
) -> List[Context]:
    """Get contexts representing endpoints from the dataset's lineage.

    Runs a lineage query starting at this artifact's ARN, filtered to context
    entities whose source is an endpoint, with edges excluded.

    Args:
        direction (LineageQueryDirectionEnum, optional): The query direction.
            Defaults to ``LineageQueryDirectionEnum.DESCENDANTS``.

    Returns:
        list of Contexts: One lineage object per vertex returned by the query.
    """
    endpoint_filter = LineageFilter(
        entities=[LineageEntityEnum.CONTEXT], sources=[LineageSourceEnum.ENDPOINT]
    )
    lineage_query = LineageQuery(self.sagemaker_session)
    result = lineage_query.query(
        start_arns=[self.artifact_arn],
        query_filter=endpoint_filter,
        direction=direction,
        include_edges=False,
    )
    return [vertex.to_lineage_object() for vertex in result.vertices]

src/sagemaker/lineage/context.py

Lines changed: 79 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from __future__ import absolute_import
1515

1616
from datetime import datetime
17-
from typing import Iterator, Optional
17+
from typing import Iterator, Optional, List
1818

1919
from sagemaker.apiutils import _base_types
2020
from sagemaker.lineage import (
@@ -23,6 +23,14 @@
2323
association,
2424
)
2525
from sagemaker.lineage._api_types import ContextSummary
26+
from sagemaker.lineage.query import (
27+
LineageQuery,
28+
LineageFilter,
29+
LineageSourceEnum,
30+
LineageEntityEnum,
31+
LineageQueryDirectionEnum,
32+
)
33+
from sagemaker.lineage.artifact import Artifact
2634

2735

2836
class Context(_base_types.Record):
@@ -252,7 +260,7 @@ def list(
252260
class EndpointContext(Context):
253261
"""An Amazon SageMaker endpoint context, which is part of a SageMaker lineage."""
254262

255-
def models(self) -> list:
263+
def models(self) -> List[association.Association]:
256264
"""Get all models deployed by all endpoint versions of the endpoint.
257265
258266
Returns:
@@ -274,3 +282,72 @@ def models(self) -> list:
274282
)
275283
]
276284
return model_list
285+
286+
def dataset_artifacts(
    self, direction: LineageQueryDirectionEnum = LineageQueryDirectionEnum.ASCENDANTS
) -> List[Artifact]:
    """Get artifacts representing datasets from the endpoint's lineage.

    Runs a lineage query starting at this context's ARN, filtered to artifact
    entities whose source is a dataset, with edges excluded.

    Args:
        direction (LineageQueryDirectionEnum, optional): The query direction.
            Defaults to ``LineageQueryDirectionEnum.ASCENDANTS``.

    Returns:
        list of Artifacts: One lineage object per vertex returned by the query.
    """
    dataset_filter = LineageFilter(
        entities=[LineageEntityEnum.ARTIFACT], sources=[LineageSourceEnum.DATASET]
    )
    lineage_query = LineageQuery(self.sagemaker_session)
    result = lineage_query.query(
        start_arns=[self.context_arn],
        query_filter=dataset_filter,
        direction=direction,
        include_edges=False,
    )
    dataset_vertices = result.vertices

    return [dataset_vertex.to_lineage_object() for dataset_vertex in dataset_vertices]
308+
309+
def training_job_arns(
    self, direction: LineageQueryDirectionEnum = LineageQueryDirectionEnum.ASCENDANTS
) -> List[str]:
    """Get ARNs for all training jobs that appear in the endpoint's lineage.

    Queries the lineage for trial-component vertices sourced from training
    jobs, describes each trial component by name, and collects the source ARN
    from each description.

    Args:
        direction (LineageQueryDirectionEnum, optional): The query direction.
            Defaults to ``LineageQueryDirectionEnum.ASCENDANTS``.

    Returns:
        list of str: Training job ARNs.
    """
    training_job_filter = LineageFilter(
        entities=[LineageEntityEnum.TRIAL_COMPONENT], sources=[LineageSourceEnum.TRAINING_JOB]
    )
    lineage_query = LineageQuery(self.sagemaker_session)
    result = lineage_query.query(
        start_arns=[self.context_arn],
        query_filter=training_job_filter,
        direction=direction,
        include_edges=False,
    )

    def _source_arn(vertex):
        # The trial component is looked up by the name embedded in its ARN.
        component_name = _utils.get_resource_name_from_arn(vertex.arn)
        description = self.sagemaker_session.sagemaker_client.describe_trial_component(
            TrialComponentName=component_name
        )
        return description["Source"]["SourceArn"]

    return [_source_arn(vertex) for vertex in result.vertices]
335+
336+
def pipeline_execution_arn(
    self, direction: LineageQueryDirectionEnum = LineageQueryDirectionEnum.ASCENDANTS
) -> Optional[str]:
    """Get the ARN for the pipeline execution associated with this endpoint (if any).

    Lists the tags of each training job ARN returned by ``training_job_arns``
    and returns the value of the first tag whose key is
    ``sagemaker:pipeline-execution-arn``.

    Args:
        direction (LineageQueryDirectionEnum, optional): The query direction
            passed through to ``training_job_arns``. Defaults to
            ``LineageQueryDirectionEnum.ASCENDANTS``.

    Returns:
        str or None: A pipeline execution ARN, or None when no training job in
        the lineage carries the pipeline execution tag.
    """
    for training_job_arn in self.training_job_arns(direction=direction):
        tags = self.sagemaker_session.sagemaker_client.list_tags(ResourceArn=training_job_arn)[
            "Tags"
        ]
        for tag in tags:
            if tag["Key"] == "sagemaker:pipeline-execution-arn":
                # First match wins; remaining training jobs are not inspected.
                return tag["Value"]

    return None

0 commit comments

Comments
 (0)