17
17
import math
18
18
19
19
from datetime import datetime
20
- from typing import Iterator , Union , Any , Optional
20
+ from typing import Iterator , Union , Any , Optional , List
21
21
22
22
from sagemaker .apiutils import _base_types , _utils
23
23
from sagemaker .lineage import _api_types
24
24
from sagemaker .lineage ._api_types import ArtifactSource , ArtifactSummary
25
- from sagemaker .lineage ._utils import get_module , _disassociate
25
+ from sagemaker .lineage .query import (
26
+ LineageQuery ,
27
+ LineageFilter ,
28
+ LineageSourceEnum ,
29
+ LineageEntityEnum ,
30
+ LineageQueryDirectionEnum ,
31
+ )
32
+ from sagemaker .lineage ._utils import get_module , _disassociate , get_resource_name_from_arn
26
33
from sagemaker .lineage .association import Association
27
34
28
35
LOGGER = logging .getLogger ("sagemaker" )
@@ -333,11 +340,13 @@ class ModelArtifact(Artifact):
333
340
"""A SageMaker lineage artifact representing a model.
334
341
335
342
Common model specific lineage traversals to discover how the model is connected
336
- to otherentities .
343
+ to other entities .
337
344
"""
338
345
346
+ from sagemaker .lineage .context import Context
347
+
339
348
def endpoints (self ) -> list :
340
- """Given a model artifact, get all associated endpoint context .
349
+ """Get association summaries for endpoints deployed with this model .
341
350
342
351
Returns:
343
352
[AssociationSummary]: A list of associations repesenting the endpoints using the model.
@@ -359,6 +368,104 @@ def endpoints(self) -> list:
359
368
]
360
369
return endpoint_context_list
361
370
371
+ def endpoint_contexts (
372
+ self , direction : LineageQueryDirectionEnum = LineageQueryDirectionEnum .DESCENDANTS
373
+ ) -> List [Context ]:
374
+ """Get contexts representing endpoints from the models's lineage.
375
+
376
+ Args:
377
+ direction (LineageQueryDirectionEnum, optional): The query direction.
378
+
379
+ Returns:
380
+ list of Contexts: Contexts representing an endpoint.
381
+ """
382
+ query_filter = LineageFilter (
383
+ entities = [LineageEntityEnum .CONTEXT ], sources = [LineageSourceEnum .ENDPOINT ]
384
+ )
385
+ query_result = LineageQuery (self .sagemaker_session ).query (
386
+ start_arns = [self .artifact_arn ],
387
+ query_filter = query_filter ,
388
+ direction = direction ,
389
+ include_edges = False ,
390
+ )
391
+
392
+ endpoint_contexts = []
393
+ for vertex in query_result .vertices :
394
+ endpoint_contexts .append (vertex .to_lineage_object ())
395
+ return endpoint_contexts
396
+
397
+ def dataset_artifacts (
398
+ self , direction : LineageQueryDirectionEnum = LineageQueryDirectionEnum .ASCENDANTS
399
+ ) -> List [Artifact ]:
400
+ """Get artifacts representing datasets from the model's lineage.
401
+
402
+ Args:
403
+ direction (LineageQueryDirectionEnum, optional): The query direction.
404
+
405
+ Returns:
406
+ list of Artifacts: Artifacts representing a dataset.
407
+ """
408
+ query_filter = LineageFilter (
409
+ entities = [LineageEntityEnum .ARTIFACT ], sources = [LineageSourceEnum .DATASET ]
410
+ )
411
+ query_result = LineageQuery (self .sagemaker_session ).query (
412
+ start_arns = [self .artifact_arn ],
413
+ query_filter = query_filter ,
414
+ direction = direction ,
415
+ include_edges = False ,
416
+ )
417
+
418
+ dataset_artifacts = []
419
+ for vertex in query_result .vertices :
420
+ dataset_artifacts .append (vertex .to_lineage_object ())
421
+ return dataset_artifacts
422
+
423
+ def training_job_arns (
424
+ self , direction : LineageQueryDirectionEnum = LineageQueryDirectionEnum .ASCENDANTS
425
+ ) -> List [str ]:
426
+ """Get ARNs for all training jobs that appear in the model's lineage.
427
+
428
+ Returns:
429
+ list of str: Training job ARNs.
430
+ """
431
+ query_filter = LineageFilter (
432
+ entities = [LineageEntityEnum .TRIAL_COMPONENT ], sources = [LineageSourceEnum .TRAINING_JOB ]
433
+ )
434
+ query_result = LineageQuery (self .sagemaker_session ).query (
435
+ start_arns = [self .artifact_arn ],
436
+ query_filter = query_filter ,
437
+ direction = direction ,
438
+ include_edges = False ,
439
+ )
440
+
441
+ training_job_arns = []
442
+ for vertex in query_result .vertices :
443
+ trial_component_name = get_resource_name_from_arn (vertex .arn )
444
+ trial_component = self .sagemaker_session .sagemaker_client .describe_trial_component (
445
+ TrialComponentName = trial_component_name
446
+ )
447
+ training_job_arns .append (trial_component ["Source" ]["SourceArn" ])
448
+ return training_job_arns
449
+
450
+ def pipeline_execution_arn (
451
+ self , direction : LineageQueryDirectionEnum = LineageQueryDirectionEnum .ASCENDANTS
452
+ ) -> str :
453
+ """Get the ARN for the pipeline execution associated with this model (if any).
454
+
455
+ Returns:
456
+ str: A pipeline execution ARN.
457
+ """
458
+ training_job_arns = self .training_job_arns (direction = direction )
459
+ for training_job_arn in training_job_arns :
460
+ tags = self .sagemaker_session .sagemaker_client .list_tags (ResourceArn = training_job_arn )[
461
+ "Tags"
462
+ ]
463
+ for tag in tags :
464
+ if tag ["Key" ] == "sagemaker:pipeline-execution-arn" :
465
+ return tag ["Value" ]
466
+
467
+ return None
468
+
362
469
363
470
class DatasetArtifact (Artifact ):
364
471
"""A SageMaker Lineage artifact representing a dataset.
@@ -367,7 +474,9 @@ class DatasetArtifact(Artifact):
367
474
connect to related entities.
368
475
"""
369
476
370
- def trained_models (self ) -> list :
477
+ from sagemaker .lineage .context import Context
478
+
479
+ def trained_models (self ) -> List [Association ]:
371
480
"""Given a dataset artifact, get associated trained models.
372
481
373
482
Returns:
@@ -387,3 +496,29 @@ def trained_models(self) -> list:
387
496
result .extend (models )
388
497
389
498
return result
499
+
500
+ def endpoint_contexts (
501
+ self , direction : LineageQueryDirectionEnum = LineageQueryDirectionEnum .DESCENDANTS
502
+ ) -> List [Context ]:
503
+ """Get contexts representing endpoints from the dataset's lineage.
504
+
505
+ Args:
506
+ direction (LineageQueryDirectionEnum, optional): The query direction.
507
+
508
+ Returns:
509
+ list of Contexts: Contexts representing an endpoint.
510
+ """
511
+ query_filter = LineageFilter (
512
+ entities = [LineageEntityEnum .CONTEXT ], sources = [LineageSourceEnum .ENDPOINT ]
513
+ )
514
+ query_result = LineageQuery (self .sagemaker_session ).query (
515
+ start_arns = [self .artifact_arn ],
516
+ query_filter = query_filter ,
517
+ direction = direction ,
518
+ include_edges = False ,
519
+ )
520
+
521
+ endpoint_contexts = []
522
+ for vertex in query_result .vertices :
523
+ endpoint_contexts .append (vertex .to_lineage_object ())
524
+ return endpoint_contexts
0 commit comments