@@ -143,10 +143,10 @@ def load(cls, artifact_arn: str, sagemaker_session=None) -> "Artifact":
143
143
return artifact
144
144
145
145
def downstream_trials (self , sagemaker_session = None ) -> list :
146
- """Retrieve all trial runs which that use this artifact.
146
+ """Use the lineage API to retrieve all downstream trials that use this artifact.
147
147
148
148
Args:
149
- sagemaker_session (obj): Sagemaker Sesssion to use. If not provided a default session
149
+ sagemaker_session (obj): Sagemaker Session to use. If not provided a default session
150
150
will be created.
151
151
152
152
Returns:
@@ -159,6 +159,51 @@ def downstream_trials(self, sagemaker_session=None) -> list:
159
159
)
160
160
trial_component_arns : list = list (map (lambda x : x .destination_arn , outgoing_associations ))
161
161
162
+ return self ._get_trial_from_trial_component (trial_component_arns )
163
+
164
+ def downstream_trials_v2 (self ) -> list :
165
+ """Use a lineage query to retrieve all downstream trials that use this artifact.
166
+
167
+ Returns:
168
+ [Trial]: A list of SageMaker `Trial` objects.
169
+ """
170
+ return self ._trials (direction = LineageQueryDirectionEnum .DESCENDANTS )
171
+
172
+ def upstream_trials (self ) -> List :
173
+ """Use the lineage query to retrieve all upstream trials that use this artifact.
174
+
175
+ Returns:
176
+ [Trial]: A list of SageMaker `Trial` objects.
177
+ """
178
+ return self ._trials (direction = LineageQueryDirectionEnum .ASCENDANTS )
179
+
180
+ def _trials (
181
+ self , direction : LineageQueryDirectionEnum = LineageQueryDirectionEnum .BOTH
182
+ ) -> List :
183
+ """Use the lineage query to retrieve all trials that use this artifact.
184
+
185
+ Args:
186
+ direction (LineageQueryDirectionEnum, optional): The query direction.
187
+
188
+ Returns:
189
+ [Trial]: A list of SageMaker `Trial` objects.
190
+ """
191
+ query_filter = LineageFilter (entities = [LineageEntityEnum .TRIAL_COMPONENT ])
192
+ query_result = LineageQuery (self .sagemaker_session ).query (
193
+ start_arns = [self .artifact_arn ],
194
+ query_filter = query_filter ,
195
+ direction = direction ,
196
+ include_edges = False ,
197
+ )
198
+ trial_component_arns : list = list (map (lambda x : x .arn , query_result .vertices ))
199
+ return self ._get_trial_from_trial_component (trial_component_arns )
200
+
201
+ def _get_trial_from_trial_component (self , trial_component_arns : list ) -> List :
202
+ """Retrieve all upstream trial runs which that use the trial component arns.
203
+
204
+ Returns:
205
+ [Trial]: A list of SageMaker `Trial` objects.
206
+ """
162
207
if not trial_component_arns :
163
208
# no outgoing associations for this artifact
164
209
return []
@@ -170,7 +215,7 @@ def downstream_trials(self, sagemaker_session=None) -> list:
170
215
num_search_batches = math .ceil (len (trial_component_arns ) % max_search_by_arn )
171
216
trial_components : list = []
172
217
173
- sagemaker_session = sagemaker_session or _utils .default_session ()
218
+ sagemaker_session = self . sagemaker_session or _utils .default_session ()
174
219
sagemaker_client = sagemaker_session .sagemaker_client
175
220
176
221
for i in range (num_search_batches ):
@@ -335,6 +380,17 @@ def list(
335
380
sagemaker_session = sagemaker_session ,
336
381
)
337
382
383
+ def s3_uri_artifacts (self , s3_uri : str ):
384
+ """Retrieve a list of artifacts that use provided s3 uri.
385
+
386
+ Args:
387
+ s3_uri (str): A S3 URI.
388
+
389
+ Returns:
390
+ A list of ``Artifacts``
391
+ """
392
+ return self .sagemaker_session .sagemaker_client .list_artifacts (SourceUri = s3_uri )
393
+
338
394
339
395
class ModelArtifact (Artifact ):
340
396
"""A SageMaker lineage artifact representing a model.
@@ -349,7 +405,7 @@ def endpoints(self) -> list:
349
405
"""Get association summaries for endpoints deployed with this model.
350
406
351
407
Returns:
352
- [AssociationSummary]: A list of associations repesenting the endpoints using the model.
408
+ [AssociationSummary]: A list of associations representing the endpoints using the model.
353
409
"""
354
410
endpoint_development_actions : Iterator = Association .list (
355
411
source_arn = self .artifact_arn ,
@@ -522,3 +578,69 @@ def endpoint_contexts(
522
578
for vertex in query_result .vertices :
523
579
endpoint_contexts .append (vertex .to_lineage_object ())
524
580
return endpoint_contexts
581
+
582
+ def upstream_datasets (self ) -> List [Artifact ]:
583
+ """Use the lineage query to retrieve upstream artifacts that use this dataset artifact.
584
+
585
+ Returns:
586
+ list of Artifacts: Artifacts representing an dataset.
587
+ """
588
+ return self ._datasets (direction = LineageQueryDirectionEnum .ASCENDANTS )
589
+
590
+ def downstream_datasets (self ) -> List [Artifact ]:
591
+ """Use the lineage query to retrieve downstream artifacts that use this dataset.
592
+
593
+ Returns:
594
+ list of Artifacts: Artifacts representing an dataset.
595
+ """
596
+ return self ._datasets (direction = LineageQueryDirectionEnum .DESCENDANTS )
597
+
598
+ def _datasets (
599
+ self , direction : LineageQueryDirectionEnum = LineageQueryDirectionEnum .BOTH
600
+ ) -> List [Artifact ]:
601
+ """Use the lineage query to retrieve all artifacts that use this dataset.
602
+
603
+ Args:
604
+ direction (LineageQueryDirectionEnum, optional): The query direction.
605
+
606
+ Returns:
607
+ list of Artifacts: Artifacts representing an dataset.
608
+ """
609
+ query_filter = LineageFilter (
610
+ entities = [LineageEntityEnum .ARTIFACT ], sources = [LineageSourceEnum .DATASET ]
611
+ )
612
+ query_result = LineageQuery (self .sagemaker_session ).query (
613
+ start_arns = [self .artifact_arn ],
614
+ query_filter = query_filter ,
615
+ direction = direction ,
616
+ include_edges = False ,
617
+ )
618
+ return [vertex .to_lineage_object () for vertex in query_result .vertices ]
619
+
620
+
621
+ class ImageArtifact (Artifact ):
622
+ """A SageMaker lineage artifact representing an image.
623
+
624
+ Common model specific lineage traversals to discover how the image is connected
625
+ to other entities.
626
+ """
627
+
628
+ def datasets (self , direction : LineageQueryDirectionEnum ) -> List [Artifact ]:
629
+ """Use the lineage query to retrieve datasets that use this image artifact.
630
+
631
+ Args:
632
+ direction (LineageQueryDirectionEnum): The query direction.
633
+
634
+ Returns:
635
+ list of Artifacts: Artifacts representing a dataset.
636
+ """
637
+ query_filter = LineageFilter (
638
+ entities = [LineageEntityEnum .ARTIFACT ], sources = [LineageSourceEnum .DATASET ]
639
+ )
640
+ query_result = LineageQuery (self .sagemaker_session ).query (
641
+ start_arns = [self .artifact_arn ],
642
+ query_filter = query_filter ,
643
+ direction = direction ,
644
+ include_edges = False ,
645
+ )
646
+ return [vertex .to_lineage_object () for vertex in query_result .vertices ]
0 commit comments