Skip to content

Commit c2548b2

Browse files
SheguftaShegufta Ahsannishkris
authored
feat: Selective Step Execution feature in Pipelines (#3898)
Co-authored-by: Shegufta Ahsan <[email protected]> Co-authored-by: Nishant Krishnamoorthy <[email protected]>
1 parent 9f2a192 commit c2548b2

File tree

6 files changed

+240
-0
lines changed

6 files changed

+240
-0
lines changed

doc/workflows/pipelines/sagemaker.workflow.pipelines.rst

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,12 @@ Pipeline Experiment Config
110110

111111
.. autoclass:: sagemaker.workflow.pipeline_experiment_config.PipelineExperimentConfigProperty
112112

113+
Selective Execution Config
114+
--------------------------
115+
116+
.. autoclass:: sagemaker.workflow.selective_execution_config.SelectiveExecutionConfig
117+
118+
113119
Properties
114120
----------
115121

src/sagemaker/local/local_session.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -505,6 +505,8 @@ def start_pipeline_execution(self, PipelineName, **kwargs):
505505
"""
506506
if "ParallelismConfiguration" in kwargs:
507507
logger.warning("Parallelism configuration is not supported in local mode.")
508+
if "SelectiveExecutionConfig" in kwargs:
509+
raise ValueError("SelectiveExecutionConfig is not supported in local mode.")
508510
if PipelineName not in LocalSagemakerClient._pipelines:
509511
error_response = {
510512
"Error": {

src/sagemaker/workflow/pipeline.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
from sagemaker.workflow.pipeline_experiment_config import PipelineExperimentConfig
4141
from sagemaker.workflow.parallelism_config import ParallelismConfiguration
4242
from sagemaker.workflow.properties import Properties
43+
from sagemaker.workflow.selective_execution_config import SelectiveExecutionConfig
4344
from sagemaker.workflow.steps import Step, StepTypeEnum
4445
from sagemaker.workflow.step_collections import StepCollection
4546
from sagemaker.workflow.condition_step import ConditionStep
@@ -312,6 +313,7 @@ def start(
312313
execution_display_name: str = None,
313314
execution_description: str = None,
314315
parallelism_config: ParallelismConfiguration = None,
316+
selective_execution_config: SelectiveExecutionConfig = None,
315317
):
316318
"""Starts a Pipeline execution in the Workflow service.
317319
@@ -323,16 +325,26 @@ def start(
323325
parallelism_config (Optional[ParallelismConfiguration]): Parallelism configuration
324326
that is applied to each of the executions of the pipeline. It takes precedence
325327
over the parallelism configuration of the parent pipeline.
328+
selective_execution_config (Optional[SelectiveExecutionConfig]): The configuration for
329+
selective step execution.
326330
327331
Returns:
328332
A `_PipelineExecution` instance, if successful.
329333
"""
334+
if selective_execution_config is not None:
335+
if selective_execution_config.source_pipeline_execution_arn is None:
336+
selective_execution_config.source_pipeline_execution_arn = (
337+
self._get_latest_execution_arn()
338+
)
339+
selective_execution_config = selective_execution_config.to_request()
340+
330341
kwargs = dict(PipelineName=self.name)
331342
update_args(
332343
kwargs,
333344
PipelineExecutionDescription=execution_description,
334345
PipelineExecutionDisplayName=execution_display_name,
335346
ParallelismConfiguration=parallelism_config,
347+
SelectiveExecutionConfig=selective_execution_config,
336348
)
337349
if self.sagemaker_session.local_mode:
338350
update_args(kwargs, PipelineParameters=parameters)
@@ -388,6 +400,57 @@ def _interpolate_step_collection_name_in_depends_on(self, step_requests: list):
388400
)
389401
self._interpolate_step_collection_name_in_depends_on(sub_step_requests)
390402

403+
def list_executions(
404+
self,
405+
sort_by: str = None,
406+
sort_order: str = None,
407+
max_results: int = None,
408+
next_token: str = None,
409+
) -> Dict[str, Any]:
410+
"""Lists a pipeline's executions.
411+
412+
Args:
413+
sort_by (str): The field by which to sort results(CreationTime/PipelineExecutionArn).
414+
sort_order (str): The sort order for results (Ascending/Descending).
415+
max_results (int): The maximum number of pipeline executions to return in the response.
416+
next_token (str): If the result of the previous ListPipelineExecutions request was
417+
truncated, the response includes a NextToken. To retrieve the next set of pipeline
418+
executions, use the token in the next request.
419+
420+
Returns:
421+
List of Pipeline Execution Summaries. See
422+
boto3 client list_pipeline_executions
423+
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.list_pipeline_executions
424+
"""
425+
kwargs = dict(PipelineName=self.name)
426+
update_args(
427+
kwargs,
428+
SortBy=sort_by,
429+
SortOrder=sort_order,
430+
NextToken=next_token,
431+
MaxResults=max_results,
432+
)
433+
response = self.sagemaker_session.sagemaker_client.list_pipeline_executions(**kwargs)
434+
435+
# Return only PipelineExecutionSummaries and NextToken from the list_pipeline_executions
436+
# response
437+
return {
438+
key: response[key]
439+
for key in ["PipelineExecutionSummaries", "NextToken"]
440+
if key in response
441+
}
442+
443+
def _get_latest_execution_arn(self):
444+
"""Retrieves the latest execution of this pipeline"""
445+
response = self.list_executions(
446+
sort_by="CreationTime",
447+
sort_order="Descending",
448+
max_results=1,
449+
)
450+
if response["PipelineExecutionSummaries"]:
451+
return response["PipelineExecutionSummaries"][0]["PipelineExecutionArn"]
452+
return None
453+
391454

392455
def format_start_parameters(parameters: Dict[str, Any]) -> List[Dict[str, Any]]:
393456
"""Formats start parameter overrides as a list of dicts.
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Pipeline Parallelism Configuration"""
14+
from __future__ import absolute_import
15+
from typing import List
16+
from sagemaker.workflow.entities import RequestType
17+
18+
19+
class SelectiveExecutionConfig:
20+
"""The selective execution configuration, which defines a subset of pipeline steps to run in
21+
22+
another SageMaker pipeline run.
23+
"""
24+
25+
def __init__(self, selected_steps: List[str], source_pipeline_execution_arn: str = None):
26+
"""Create a `SelectiveExecutionConfig`.
27+
28+
Args:
29+
source_pipeline_execution_arn (str): The ARN from a reference execution of the
30+
current pipeline. Used to copy input collaterals needed for the selected
31+
steps to run. The execution status of the pipeline can be `Stopped`, `Failed`, or
32+
`Succeeded`.
33+
selected_steps (List[str]): A list of pipeline steps to run. All step(s) in all
34+
path(s) between two selected steps should be included.
35+
"""
36+
self.source_pipeline_execution_arn = source_pipeline_execution_arn
37+
self.selected_steps = selected_steps
38+
39+
def _build_selected_steps_from_list(self) -> RequestType:
40+
"""Get the request structure for the list of selected steps."""
41+
selected_step_list = []
42+
for selected_step in self.selected_steps:
43+
selected_step_list.append(dict(StepName=selected_step))
44+
return selected_step_list
45+
46+
def to_request(self) -> RequestType:
47+
"""Convert `SelectiveExecutionConfig` object to request dict."""
48+
request = {}
49+
50+
if self.source_pipeline_execution_arn is not None:
51+
request["SourcePipelineExecutionArn"] = self.source_pipeline_execution_arn
52+
53+
if self.selected_steps is not None:
54+
request["SelectedSteps"] = self._build_selected_steps_from_list()
55+
56+
return request

tests/unit/sagemaker/workflow/test_pipeline.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
from tests.unit.sagemaker.workflow.helpers import ordered, CustomStep
3636
from sagemaker.local.local_session import LocalSession
3737
from botocore.exceptions import ClientError
38+
from sagemaker.workflow.selective_execution_config import SelectiveExecutionConfig
3839

3940

4041
@pytest.fixture
@@ -425,6 +426,66 @@ def test_pipeline_start(sagemaker_session_mock):
425426
)
426427

427428

429+
def test_pipeline_start_selective_execution(sagemaker_session_mock):
430+
sagemaker_session_mock.sagemaker_client.start_pipeline_execution.return_value = {
431+
"PipelineExecutionArn": "my:arn"
432+
}
433+
pipeline = Pipeline(
434+
name="MyPipeline",
435+
parameters=[],
436+
steps=[],
437+
sagemaker_session=sagemaker_session_mock,
438+
)
439+
440+
# Case 1: Happy path
441+
selective_execution_config = SelectiveExecutionConfig(
442+
source_pipeline_execution_arn="foo-arn", selected_steps=["step-1", "step-2", "step-3"]
443+
)
444+
pipeline.start(selective_execution_config=selective_execution_config)
445+
sagemaker_session_mock.sagemaker_client.start_pipeline_execution.assert_called_with(
446+
PipelineName="MyPipeline",
447+
SelectiveExecutionConfig={
448+
"SelectedSteps": [
449+
{"StepName": "step-1"},
450+
{"StepName": "step-2"},
451+
{"StepName": "step-3"},
452+
],
453+
"SourcePipelineExecutionArn": "foo-arn",
454+
},
455+
)
456+
457+
# Case 2: Start selective execution without SourcePipelineExecutionArn
458+
sagemaker_session_mock.sagemaker_client.list_pipeline_executions.return_value = {
459+
"PipelineExecutionSummaries": [
460+
{
461+
"PipelineExecutionArn": "my:latest:execution:arn",
462+
"PipelineExecutionDisplayName": "Latest",
463+
}
464+
]
465+
}
466+
selective_execution_config = SelectiveExecutionConfig(
467+
selected_steps=["step-1", "step-2", "step-3"]
468+
)
469+
pipeline.start(selective_execution_config=selective_execution_config)
470+
sagemaker_session_mock.sagemaker_client.list_pipeline_executions.assert_called_with(
471+
PipelineName="MyPipeline",
472+
SortBy="CreationTime",
473+
SortOrder="Descending",
474+
MaxResults=1,
475+
)
476+
sagemaker_session_mock.sagemaker_client.start_pipeline_execution.assert_called_with(
477+
PipelineName="MyPipeline",
478+
SelectiveExecutionConfig={
479+
"SelectedSteps": [
480+
{"StepName": "step-1"},
481+
{"StepName": "step-2"},
482+
{"StepName": "step-3"},
483+
],
484+
"SourcePipelineExecutionArn": "my:latest:execution:arn",
485+
},
486+
)
487+
488+
428489
def test_pipeline_basic():
429490
parameter = ParameterString("MyStr")
430491
pipeline = Pipeline(
@@ -593,6 +654,31 @@ def test_pipeline_disable_experiment_config():
593654
)
594655

595656

657+
def test_pipeline_list_executions(sagemaker_session_mock):
658+
sagemaker_session_mock.sagemaker_client.list_pipeline_executions.return_value = {
659+
"PipelineExecutionSummaries": [Mock()],
660+
"ResponseMetadata": "metadata",
661+
}
662+
pipeline = Pipeline(
663+
name="MyPipeline",
664+
parameters=[ParameterString("alpha", "beta"), ParameterString("gamma", "delta")],
665+
steps=[],
666+
sagemaker_session=sagemaker_session_mock,
667+
)
668+
executions = pipeline.list_executions()
669+
assert len(executions) == 1
670+
assert len(executions["PipelineExecutionSummaries"]) == 1
671+
sagemaker_session_mock.sagemaker_client.list_pipeline_executions.return_value = {
672+
"PipelineExecutionSummaries": [Mock(), Mock()],
673+
"NextToken": "token",
674+
"ResponseMetadata": "metadata",
675+
}
676+
executions = pipeline.list_executions()
677+
assert len(executions) == 2
678+
assert len(executions["PipelineExecutionSummaries"]) == 2
679+
assert executions["NextToken"] == "token"
680+
681+
596682
def test_pipeline_execution_basics(sagemaker_session_mock):
597683
sagemaker_session_mock.sagemaker_client.start_pipeline_execution.return_value = {
598684
"PipelineExecutionArn": "my:arn"
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Unit tests for SelectiveExecutionConfig module"""
14+
15+
from __future__ import absolute_import
16+
17+
from sagemaker.workflow.selective_execution_config import SelectiveExecutionConfig
18+
19+
20+
def test_SelectiveExecutionConfig():
21+
selective_execution_config = SelectiveExecutionConfig(
22+
source_pipeline_execution_arn="foo-arn", selected_steps=["step-1", "step-2", "step-3"]
23+
)
24+
assert selective_execution_config.to_request() == {
25+
"SelectedSteps": [{"StepName": "step-1"}, {"StepName": "step-2"}, {"StepName": "step-3"}],
26+
"SourcePipelineExecutionArn": "foo-arn",
27+
}

0 commit comments

Comments
 (0)