Skip to content

Commit ab00bac

Browse files
nishkrisknikure
authored andcommitted
feature: Default selective execution source pipeline to latest pipeline execution (aws#931)
1 parent 4623bd6 commit ab00bac

File tree

3 files changed

+60
-1
lines changed

3 files changed

+60
-1
lines changed

src/sagemaker/workflow/pipeline.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -332,6 +332,10 @@ def start(
332332
A `_PipelineExecution` instance, if successful.
333333
"""
334334
if selective_execution_config is not None:
335+
if selective_execution_config.source_pipeline_execution_arn is None:
336+
selective_execution_config.source_pipeline_execution_arn = (
337+
self._get_latest_execution_arn()
338+
)
335339
selective_execution_config = selective_execution_config.to_request()
336340

337341
kwargs = dict(PipelineName=self.name)
@@ -436,6 +440,17 @@ def list_executions(
436440
if key in response
437441
}
438442

443+
def _get_latest_execution_arn(self):
444+
"""Retrieves the latest execution of this pipeline"""
445+
response = self.list_executions(
446+
sort_by="CreationTime",
447+
sort_order="Descending",
448+
max_results=1,
449+
)
450+
if response["PipelineExecutionSummaries"]:
451+
return response["PipelineExecutionSummaries"][0]["PipelineExecutionArn"]
452+
return None
453+
439454

440455
def format_start_parameters(parameters: Dict[str, Any]) -> List[Dict[str, Any]]:
441456
"""Formats start parameter overrides as a list of dicts.

src/sagemaker/workflow/selective_execution_config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
class SelectiveExecutionConfig:
2020
"""Selective execution config config for SageMaker pipeline."""
2121

22-
def __init__(self, source_pipeline_execution_arn: str, selected_steps: List[str]):
22+
def __init__(self, selected_steps: List[str], source_pipeline_execution_arn: str = None):
2323
"""Create a SelectiveExecutionConfig
2424
2525
Args:

tests/unit/sagemaker/workflow/test_pipeline.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -425,6 +425,19 @@ def test_pipeline_start(sagemaker_session_mock):
425425
PipelineName="MyPipeline", PipelineParameters=[{"Name": "alpha", "Value": "epsilon"}]
426426
)
427427

428+
429+
def test_pipeline_start_selective_execution(sagemaker_session_mock):
430+
sagemaker_session_mock.sagemaker_client.start_pipeline_execution.return_value = {
431+
"PipelineExecutionArn": "my:arn"
432+
}
433+
pipeline = Pipeline(
434+
name="MyPipeline",
435+
parameters=[],
436+
steps=[],
437+
sagemaker_session=sagemaker_session_mock,
438+
)
439+
440+
# Case 1: Happy path
428441
selective_execution_config = SelectiveExecutionConfig(
429442
source_pipeline_execution_arn="foo-arn", selected_steps=["step-1", "step-2", "step-3"]
430443
)
@@ -441,6 +454,37 @@ def test_pipeline_start(sagemaker_session_mock):
441454
},
442455
)
443456

457+
# Case 2: Start selective execution without SourcePipelineExecutionArn
458+
sagemaker_session_mock.sagemaker_client.list_pipeline_executions.return_value = {
459+
"PipelineExecutionSummaries": [
460+
{
461+
"PipelineExecutionArn": "my:latest:execution:arn",
462+
"PipelineExecutionDisplayName": "Latest",
463+
}
464+
]
465+
}
466+
selective_execution_config = SelectiveExecutionConfig(
467+
selected_steps=["step-1", "step-2", "step-3"]
468+
)
469+
pipeline.start(selective_execution_config=selective_execution_config)
470+
sagemaker_session_mock.sagemaker_client.list_pipeline_executions.assert_called_with(
471+
PipelineName="MyPipeline",
472+
SortBy="CreationTime",
473+
SortOrder="Descending",
474+
MaxResults=1,
475+
)
476+
sagemaker_session_mock.sagemaker_client.start_pipeline_execution.assert_called_with(
477+
PipelineName="MyPipeline",
478+
SelectiveExecutionConfig={
479+
"SelectedSteps": [
480+
{"StepName": "step-1"},
481+
{"StepName": "step-2"},
482+
{"StepName": "step-3"},
483+
],
484+
"SourcePipelineExecutionArn": "my:latest:execution:arn",
485+
},
486+
)
487+
444488

445489
def test_pipeline_basic():
446490
parameter = ParameterString("MyStr")

0 commit comments

Comments
 (0)