Skip to content

Commit f45ba86

Browse files
nishkrisShegufta Ahsan
authored andcommitted
feature: Default selective execution source pipeline to latest pipeline execution (aws#931)
1 parent 1283d3a commit f45ba86

File tree

3 files changed

+60
-1
lines changed

3 files changed

+60
-1
lines changed

src/sagemaker/workflow/pipeline.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -332,6 +332,10 @@ def start(
332332
A `_PipelineExecution` instance, if successful.
333333
"""
334334
if selective_execution_config is not None:
335+
if selective_execution_config.source_pipeline_execution_arn is None:
336+
selective_execution_config.source_pipeline_execution_arn = (
337+
self._get_latest_execution_arn()
338+
)
335339
selective_execution_config = selective_execution_config.to_request()
336340

337341
kwargs = dict(PipelineName=self.name)
@@ -436,6 +440,17 @@ def list_executions(
436440
if key in response
437441
}
438442

443+
def _get_latest_execution_arn(self):
444+
"""Retrieves the latest execution of this pipeline"""
445+
response = self.list_executions(
446+
sort_by="CreationTime",
447+
sort_order="Descending",
448+
max_results=1,
449+
)
450+
if response["PipelineExecutionSummaries"]:
451+
return response["PipelineExecutionSummaries"][0]["PipelineExecutionArn"]
452+
return None
453+
439454

440455
def format_start_parameters(parameters: Dict[str, Any]) -> List[Dict[str, Any]]:
441456
"""Formats start parameter overrides as a list of dicts.

src/sagemaker/workflow/selective_execution_config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
class SelectiveExecutionConfig:
2020
"""Selective execution config config for SageMaker pipeline."""
2121

22-
def __init__(self, source_pipeline_execution_arn: str, selected_steps: List[str]):
22+
def __init__(self, selected_steps: List[str], source_pipeline_execution_arn: str = None):
2323
"""Create a SelectiveExecutionConfig
2424
2525
Args:

tests/unit/sagemaker/workflow/test_pipeline.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -423,6 +423,19 @@ def test_pipeline_start(sagemaker_session_mock):
423423
PipelineName="MyPipeline", PipelineParameters=[{"Name": "alpha", "Value": "epsilon"}]
424424
)
425425

426+
427+
def test_pipeline_start_selective_execution(sagemaker_session_mock):
428+
sagemaker_session_mock.sagemaker_client.start_pipeline_execution.return_value = {
429+
"PipelineExecutionArn": "my:arn"
430+
}
431+
pipeline = Pipeline(
432+
name="MyPipeline",
433+
parameters=[],
434+
steps=[],
435+
sagemaker_session=sagemaker_session_mock,
436+
)
437+
438+
# Case 1: Happy path
426439
selective_execution_config = SelectiveExecutionConfig(
427440
source_pipeline_execution_arn="foo-arn", selected_steps=["step-1", "step-2", "step-3"]
428441
)
@@ -439,6 +452,37 @@ def test_pipeline_start(sagemaker_session_mock):
439452
},
440453
)
441454

455+
# Case 2: Start selective execution without SourcePipelineExecutionArn
456+
sagemaker_session_mock.sagemaker_client.list_pipeline_executions.return_value = {
457+
"PipelineExecutionSummaries": [
458+
{
459+
"PipelineExecutionArn": "my:latest:execution:arn",
460+
"PipelineExecutionDisplayName": "Latest",
461+
}
462+
]
463+
}
464+
selective_execution_config = SelectiveExecutionConfig(
465+
selected_steps=["step-1", "step-2", "step-3"]
466+
)
467+
pipeline.start(selective_execution_config=selective_execution_config)
468+
sagemaker_session_mock.sagemaker_client.list_pipeline_executions.assert_called_with(
469+
PipelineName="MyPipeline",
470+
SortBy="CreationTime",
471+
SortOrder="Descending",
472+
MaxResults=1,
473+
)
474+
sagemaker_session_mock.sagemaker_client.start_pipeline_execution.assert_called_with(
475+
PipelineName="MyPipeline",
476+
SelectiveExecutionConfig={
477+
"SelectedSteps": [
478+
{"StepName": "step-1"},
479+
{"StepName": "step-2"},
480+
{"StepName": "step-3"},
481+
],
482+
"SourcePipelineExecutionArn": "my:latest:execution:arn",
483+
},
484+
)
485+
442486

443487
def test_pipeline_basic():
444488
parameter = ParameterString("MyStr")

0 commit comments

Comments
 (0)