Skip to content

Commit d83b6fd

Browse files
SheguftaShegufta Ahsan
authored andcommitted
feature: Add SelectiveExecutionConfig (aws#929)
Co-authored-by: Shegufta Ahsan <[email protected]>
1 parent 9f2a192 commit d83b6fd

File tree

5 files changed

+105
-0
lines changed

5 files changed

+105
-0
lines changed

src/sagemaker/local/local_session.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -505,6 +505,8 @@ def start_pipeline_execution(self, PipelineName, **kwargs):
505505
"""
506506
if "ParallelismConfiguration" in kwargs:
507507
logger.warning("Parallelism configuration is not supported in local mode.")
508+
if "SelectiveExecutionConfig" in kwargs:
509+
raise ValueError("SelectiveExecutionConfig is not supported in local mode.")
508510
if PipelineName not in LocalSagemakerClient._pipelines:
509511
error_response = {
510512
"Error": {

src/sagemaker/workflow/pipeline.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
from sagemaker.workflow.pipeline_experiment_config import PipelineExperimentConfig
4141
from sagemaker.workflow.parallelism_config import ParallelismConfiguration
4242
from sagemaker.workflow.properties import Properties
43+
from sagemaker.workflow.selective_execution_config import SelectiveExecutionConfig
4344
from sagemaker.workflow.steps import Step, StepTypeEnum
4445
from sagemaker.workflow.step_collections import StepCollection
4546
from sagemaker.workflow.condition_step import ConditionStep
@@ -312,6 +313,7 @@ def start(
312313
execution_display_name: str = None,
313314
execution_description: str = None,
314315
parallelism_config: ParallelismConfiguration = None,
316+
selective_execution_config: SelectiveExecutionConfig = None,
315317
):
316318
"""Starts a Pipeline execution in the Workflow service.
317319
@@ -323,16 +325,22 @@ def start(
323325
parallelism_config (Optional[ParallelismConfiguration]): Parallelism configuration
324326
that is applied to each of the executions of the pipeline. It takes precedence
325327
over the parallelism configuration of the parent pipeline.
328+
selective_execution_config (SelectiveExecutionConfig): configuration for
329+
selective step execution
326330
327331
Returns:
328332
A `_PipelineExecution` instance, if successful.
329333
"""
334+
if selective_execution_config is not None:
335+
selective_execution_config = selective_execution_config.to_request()
336+
330337
kwargs = dict(PipelineName=self.name)
331338
update_args(
332339
kwargs,
333340
PipelineExecutionDescription=execution_description,
334341
PipelineExecutionDisplayName=execution_display_name,
335342
ParallelismConfiguration=parallelism_config,
343+
SelectiveExecutionConfig=selective_execution_config,
336344
)
337345
if self.sagemaker_session.local_mode:
338346
update_args(kwargs, PipelineParameters=parameters)
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Pipeline Parallelism Configuration"""
14+
from __future__ import absolute_import
15+
from typing import List
16+
from sagemaker.workflow.entities import RequestType
17+
18+
19+
class SelectiveExecutionConfig:
20+
"""Selective execution config config for SageMaker pipeline."""
21+
22+
def __init__(self, source_pipeline_execution_arn: str, selected_steps: List[str]):
23+
"""Create a SelectiveExecutionConfig
24+
25+
Args:
26+
source_pipeline_execution_arn (str): ARN of the previously executed pipeline execution.
27+
The given arn pipeline execution status can be either Failed or Success.
28+
selected_steps (List[str]): List of step names that pipeline users want to run
29+
in new subworkflow-execution. The steps must be connected.
30+
"""
31+
self.source_pipeline_execution_arn = source_pipeline_execution_arn
32+
self.selected_steps = selected_steps
33+
34+
def _build_selected_steps_from_list(self) -> RequestType:
35+
"""Get the request structure for list of selected steps"""
36+
selected_step_list = []
37+
for selected_step in self.selected_steps:
38+
selected_step_list.append(dict(StepName=selected_step))
39+
return selected_step_list
40+
41+
def to_request(self) -> RequestType:
42+
"""Convert SelectiveExecutionConfig object to request dict."""
43+
request = {}
44+
45+
if self.source_pipeline_execution_arn is not None:
46+
request["SourcePipelineExecutionArn"] = self.source_pipeline_execution_arn
47+
48+
if self.selected_steps is not None:
49+
request["SelectedSteps"] = self._build_selected_steps_from_list()
50+
51+
return request

tests/unit/sagemaker/workflow/test_pipeline.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
from tests.unit.sagemaker.workflow.helpers import ordered, CustomStep
3636
from sagemaker.local.local_session import LocalSession
3737
from botocore.exceptions import ClientError
38+
from sagemaker.workflow.selective_execution_config import SelectiveExecutionConfig
3839

3940

4041
@pytest.fixture
@@ -424,6 +425,22 @@ def test_pipeline_start(sagemaker_session_mock):
424425
PipelineName="MyPipeline", PipelineParameters=[{"Name": "alpha", "Value": "epsilon"}]
425426
)
426427

428+
selective_execution_config = SelectiveExecutionConfig(
429+
source_pipeline_execution_arn="foo-arn", selected_steps=["step-1", "step-2", "step-3"]
430+
)
431+
pipeline.start(selective_execution_config=selective_execution_config)
432+
sagemaker_session_mock.sagemaker_client.start_pipeline_execution.assert_called_with(
433+
PipelineName="MyPipeline",
434+
SelectiveExecutionConfig={
435+
"SelectedSteps": [
436+
{"StepName": "step-1"},
437+
{"StepName": "step-2"},
438+
{"StepName": "step-3"},
439+
],
440+
"SourcePipelineExecutionArn": "foo-arn",
441+
},
442+
)
443+
427444

428445
def test_pipeline_basic():
429446
parameter = ParameterString("MyStr")
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Unit tests for SelectiveExecutionConfig module"""
14+
15+
from __future__ import absolute_import
16+
17+
from sagemaker.workflow.selective_execution_config import SelectiveExecutionConfig
18+
19+
20+
def test_SelectiveExecutionConfig():
21+
selective_execution_config = SelectiveExecutionConfig(
22+
source_pipeline_execution_arn="foo-arn", selected_steps=["step-1", "step-2", "step-3"]
23+
)
24+
assert selective_execution_config.to_request() == {
25+
"SelectedSteps": [{"StepName": "step-1"}, {"StepName": "step-2"}, {"StepName": "step-3"}],
26+
"SourcePipelineExecutionArn": "foo-arn",
27+
}

0 commit comments

Comments
 (0)