@@ -146,7 +146,7 @@ def downstream_trials(self, sagemaker_session=None) -> list:
146
146
"""Retrieve all trial runs which that use this artifact.
147
147
148
148
Args:
149
- sagemaker_session (obj): Sagemaker Sesssion to use. If not provided a default session
149
+ sagemaker_session (obj): Sagemaker Session to use. If not provided a default session
150
150
will be created.
151
151
152
152
Returns:
@@ -159,6 +159,57 @@ def downstream_trials(self, sagemaker_session=None) -> list:
159
159
)
160
160
trial_component_arns : list = list (map (lambda x : x .destination_arn , outgoing_associations ))
161
161
162
+ return self ._get_trial_from_trial_component (trial_component_arns )
163
+
164
+ def downstream_trials_v2 (
165
+ self , direction : LineageQueryDirectionEnum = LineageQueryDirectionEnum .DESCENDANTS
166
+ ) -> list :
167
+ """Retrieve all downstream trial runs which that use this artifact by using lineage query.
168
+
169
+ Args:
170
+ direction (LineageQueryDirectionEnum, optional): The query direction.
171
+
172
+ Returns:
173
+ [Trial]: A list of SageMaker `Trial` objects.
174
+ """
175
+ query_filter = LineageFilter (entities = [LineageEntityEnum .TRIAL_COMPONENT ])
176
+ query_result = LineageQuery (self .sagemaker_session ).query (
177
+ start_arns = [self .artifact_arn ],
178
+ query_filter = query_filter ,
179
+ direction = direction ,
180
+ include_edges = False ,
181
+ )
182
+ trial_component_arns : list = list (map (lambda x : x .arn , query_result .vertices ))
183
+
184
+ return self ._get_trial_from_trial_component (trial_component_arns )
185
+
186
+ def get_upstream_trials (
187
+ self , direction : LineageQueryDirectionEnum = LineageQueryDirectionEnum .ASCENDANTS
188
+ ) -> List [str ]:
189
+ """Retrieve all upstream trial runs which that use this artifact by using lineage query.
190
+
191
+ Args:
192
+ direction (LineageQueryDirectionEnum, optional): The query direction.
193
+
194
+ Returns:
195
+ [Trial]: A list of SageMaker `Trial` objects.
196
+ """
197
+ query_filter = LineageFilter (entities = [LineageEntityEnum .TRIAL_COMPONENT ])
198
+ query_result = LineageQuery (self .sagemaker_session ).query (
199
+ start_arns = [self .artifact_arn ],
200
+ query_filter = query_filter ,
201
+ direction = direction ,
202
+ include_edges = False ,
203
+ )
204
+ trial_component_arns : list = list (map (lambda x : x .arn , query_result .vertices ))
205
+ return self ._get_trial_from_trial_component (trial_component_arns )
206
+
207
+ def _get_trial_from_trial_component (self , trial_component_arns : list ):
208
+ """Retrieve all upstream trial runs which that use the trial component arns.
209
+
210
+ Returns:
211
+ [Trial]: A list of SageMaker `Trial` objects.
212
+ """
162
213
if not trial_component_arns :
163
214
# no outgoing associations for this artifact
164
215
return []
@@ -170,7 +221,7 @@ def downstream_trials(self, sagemaker_session=None) -> list:
170
221
num_search_batches = math .ceil (len (trial_component_arns ) % max_search_by_arn )
171
222
trial_components : list = []
172
223
173
- sagemaker_session = sagemaker_session or _utils .default_session ()
224
+ sagemaker_session = self . sagemaker_session or _utils .default_session ()
174
225
sagemaker_client = sagemaker_session .sagemaker_client
175
226
176
227
for i in range (num_search_batches ):
@@ -335,6 +386,17 @@ def list(
335
386
sagemaker_session = sagemaker_session ,
336
387
)
337
388
389
+ def get_artifacts (self , s3_uri : str ):
390
+ """Return a list of artifact that use provided s3 uri.
391
+
392
+ Args:
393
+ s3_uri (str): A S3 URI.
394
+
395
+ Returns:
396
+ A list of ``Artifacts``
397
+ """
398
+ return self .sagemaker_session .sagemaker_client .list_artifacts (SourceUri = s3_uri )
399
+
338
400
339
401
class ModelArtifact (Artifact ):
340
402
"""A SageMaker lineage artifact representing a model.
@@ -349,7 +411,7 @@ def endpoints(self) -> list:
349
411
"""Get association summaries for endpoints deployed with this model.
350
412
351
413
Returns:
352
- [AssociationSummary]: A list of associations repesenting the endpoints using the model.
414
+ [AssociationSummary]: A list of associations representing the endpoints using the model.
353
415
"""
354
416
endpoint_development_actions : Iterator = Association .list (
355
417
source_arn = self .artifact_arn ,
@@ -522,3 +584,75 @@ def endpoint_contexts(
522
584
for vertex in query_result .vertices :
523
585
endpoint_contexts .append (vertex .to_lineage_object ())
524
586
return endpoint_contexts
587
+
588
+ def get_upstream_datasets (
589
+ self , direction : LineageQueryDirectionEnum = LineageQueryDirectionEnum .ASCENDANTS
590
+ ) -> List [Artifact ]:
591
+ """Get upstream artifacts representing dataset from the dataset's lineage.
592
+
593
+ Args:
594
+ direction (LineageQueryDirectionEnum, optional): The query direction.
595
+
596
+ Returns:
597
+ list of Artifacts: Artifacts representing an dataset.
598
+ """
599
+ query_filter = LineageFilter (
600
+ entities = [LineageEntityEnum .ARTIFACT ], sources = [LineageSourceEnum .DATASET ]
601
+ )
602
+ query_result = LineageQuery (self .sagemaker_session ).query (
603
+ start_arns = [self .artifact_arn ],
604
+ query_filter = query_filter ,
605
+ direction = direction ,
606
+ include_edges = False ,
607
+ )
608
+ return [vertex .to_lineage_object () for vertex in query_result .vertices ]
609
+
610
+ def get_downstream_datasets (
611
+ self , direction : LineageQueryDirectionEnum = LineageQueryDirectionEnum .DESCENDANTS
612
+ ) -> List [Artifact ]:
613
+ """Get downstream artifacts representing dataset from the dataset's lineage.
614
+
615
+ Args:
616
+ direction (LineageQueryDirectionEnum, optional): The query direction.
617
+
618
+ Returns:
619
+ list of Artifacts: Artifacts representing an dataset.
620
+ """
621
+ query_filter = LineageFilter (
622
+ entities = [LineageEntityEnum .ARTIFACT ], sources = [LineageSourceEnum .DATASET ]
623
+ )
624
+ query_result = LineageQuery (self .sagemaker_session ).query (
625
+ start_arns = [self .artifact_arn ],
626
+ query_filter = query_filter ,
627
+ direction = direction ,
628
+ include_edges = False ,
629
+ )
630
+ return [vertex .to_lineage_object () for vertex in query_result .vertices ]
631
+
632
+
633
+ class ImageArtifact (Artifact ):
634
+ """A SageMaker lineage artifact representing an image.
635
+
636
+ Common model specific lineage traversals to discover how the image is connected
637
+ to other entities.
638
+ """
639
+
640
+ def get_dataset (self , direction : LineageQueryDirectionEnum ) -> List [Artifact ]:
641
+ """Get artifacts representing dataset from the image artifact's lineage.
642
+
643
+ Args:
644
+ direction (LineageQueryDirectionEnum): The query direction.
645
+
646
+ Returns:
647
+ list of Artifacts: Artifacts representing an dataset.
648
+ """
649
+ query_filter = LineageFilter (
650
+ entities = [LineageEntityEnum .ARTIFACT ], sources = [LineageSourceEnum .DATASET ]
651
+ )
652
+ query_result = LineageQuery (self .sagemaker_session ).query (
653
+ start_arns = [self .artifact_arn ],
654
+ query_filter = query_filter ,
655
+ direction = direction ,
656
+ include_edges = False ,
657
+ )
658
+ return [vertex .to_lineage_object () for vertex in query_result .vertices ]
0 commit comments