Skip to content

Commit 4623bd6

Browse files
nishkrisknikure
authored andcommitted
feature: Add support for listing executions from pipeline (aws#930)
1 parent d83b6fd commit 4623bd6

File tree

2 files changed

+65
-0
lines changed

2 files changed

+65
-0
lines changed

src/sagemaker/workflow/pipeline.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -396,6 +396,46 @@ def _interpolate_step_collection_name_in_depends_on(self, step_requests: list):
396396
)
397397
self._interpolate_step_collection_name_in_depends_on(sub_step_requests)
398398

399+
def list_executions(
400+
self,
401+
sort_by: str = None,
402+
sort_order: str = None,
403+
max_results: int = None,
404+
next_token: str = None,
405+
) -> Dict[str, Any]:
406+
"""Lists a pipeline's executions.
407+
408+
Args:
409+
sort_by (str): The field by which to sort results(CreationTime/PipelineExecutionArn).
410+
sort_order (str): The sort order for results (Ascending/Descending).
411+
max_results (int): The maximum number of pipeline executions to return in the response.
412+
next_token (str): If the result of the previous ListPipelineExecutions request was
413+
truncated, the response includes a NextToken. To retrieve the next set of pipeline
414+
executions, use the token in the next request.
415+
416+
Returns:
417+
List of Pipeline Execution Summaries. See
418+
boto3 client list_pipeline_executions
419+
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.list_pipeline_executions
420+
"""
421+
kwargs = dict(PipelineName=self.name)
422+
update_args(
423+
kwargs,
424+
SortBy=sort_by,
425+
SortOrder=sort_order,
426+
NextToken=next_token,
427+
MaxResults=max_results,
428+
)
429+
response = self.sagemaker_session.sagemaker_client.list_pipeline_executions(**kwargs)
430+
431+
# Return only PipelineExecutionSummaries and NextToken from the list_pipeline_executions
432+
# response
433+
return {
434+
key: response[key]
435+
for key in ["PipelineExecutionSummaries", "NextToken"]
436+
if key in response
437+
}
438+
399439

400440
def format_start_parameters(parameters: Dict[str, Any]) -> List[Dict[str, Any]]:
401441
"""Formats start parameter overrides as a list of dicts.

tests/unit/sagemaker/workflow/test_pipeline.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -610,6 +610,31 @@ def test_pipeline_disable_experiment_config():
610610
)
611611

612612

613+
def test_pipeline_list_executions(sagemaker_session_mock):
614+
sagemaker_session_mock.sagemaker_client.list_pipeline_executions.return_value = {
615+
"PipelineExecutionSummaries": [Mock()],
616+
"ResponseMetadata": "metadata",
617+
}
618+
pipeline = Pipeline(
619+
name="MyPipeline",
620+
parameters=[ParameterString("alpha", "beta"), ParameterString("gamma", "delta")],
621+
steps=[],
622+
sagemaker_session=sagemaker_session_mock,
623+
)
624+
executions = pipeline.list_executions()
625+
assert len(executions) == 1
626+
assert len(executions["PipelineExecutionSummaries"]) == 1
627+
sagemaker_session_mock.sagemaker_client.list_pipeline_executions.return_value = {
628+
"PipelineExecutionSummaries": [Mock(), Mock()],
629+
"NextToken": "token",
630+
"ResponseMetadata": "metadata",
631+
}
632+
executions = pipeline.list_executions()
633+
assert len(executions) == 2
634+
assert len(executions["PipelineExecutionSummaries"]) == 2
635+
assert executions["NextToken"] == "token"
636+
637+
613638
def test_pipeline_execution_basics(sagemaker_session_mock):
614639
sagemaker_session_mock.sagemaker_client.start_pipeline_execution.return_value = {
615640
"PipelineExecutionArn": "my:arn"

0 commit comments

Comments
 (0)