Skip to content

Commit ccc494e

Browse files
author
Dewen Qi
committed
fix: Allow StepCollection added in ConditionStep to be depended on
1 parent b4f05b8 commit ccc494e

File tree

4 files changed

+243
-69
lines changed

4 files changed

+243
-69
lines changed

src/sagemaker/workflow/pipeline.py

Lines changed: 71 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -37,48 +37,58 @@
3737
from sagemaker.workflow.pipeline_experiment_config import PipelineExperimentConfig
3838
from sagemaker.workflow.parallelism_config import ParallelismConfiguration
3939
from sagemaker.workflow.properties import Properties
40-
from sagemaker.workflow.steps import Step
40+
from sagemaker.workflow.steps import Step, StepTypeEnum
4141
from sagemaker.workflow.step_collections import StepCollection
4242
from sagemaker.workflow.condition_step import ConditionStep
4343
from sagemaker.workflow.utilities import list_to_request
4444

45+
_DEFAULT_EXPERIMENT_CFG = PipelineExperimentConfig(
46+
ExecutionVariables.PIPELINE_NAME, ExecutionVariables.PIPELINE_EXECUTION_ID
47+
)
48+
4549

46-
@attr.s
4750
class Pipeline(Entity):
48-
"""Pipeline for workflow.
51+
"""Pipeline for workflow."""
4952

50-
Attributes:
51-
name (str): The name of the pipeline.
52-
parameters (Sequence[Parameter]): The list of the parameters.
53-
pipeline_experiment_config (Optional[PipelineExperimentConfig]): If set,
54-
the workflow will attempt to create an experiment and trial before
55-
executing the steps. Creation will be skipped if an experiment or a trial with
56-
the same name already exists. By default, pipeline name is used as
57-
experiment name and execution id is used as the trial name.
58-
If set to None, no experiment or trial will be created automatically.
59-
steps (Sequence[Union[Step, StepCollection]]): The list of the non-conditional steps
60-
associated with the pipeline. Any steps that are within the
61-
`if_steps` or `else_steps` of a `ConditionStep` cannot be listed in the steps of a
62-
pipeline. Of particular note, the workflow service rejects any pipeline definitions that
63-
specify a step in the list of steps of a pipeline and that step in the `if_steps` or
64-
`else_steps` of any `ConditionStep`.
65-
sagemaker_session (sagemaker.session.Session): Session object that manages interactions
66-
with Amazon SageMaker APIs and any other AWS services needed. If not specified, the
67-
pipeline creates one using the default AWS configuration chain.
68-
"""
53+
def __init__(
54+
self,
55+
name: str = "",
56+
parameters: Optional[Sequence[Parameter]] = None,
57+
pipeline_experiment_config: Optional[PipelineExperimentConfig] = _DEFAULT_EXPERIMENT_CFG,
58+
steps: Optional[Sequence[Union[Step, StepCollection]]] = None,
59+
sagemaker_session: Session = Session(),
60+
):
61+
"""Initialize a Pipeline
6962
70-
name: str = attr.ib(factory=str)
71-
parameters: Sequence[Parameter] = attr.ib(factory=list)
72-
pipeline_experiment_config: Optional[PipelineExperimentConfig] = attr.ib(
73-
default=PipelineExperimentConfig(
74-
ExecutionVariables.PIPELINE_NAME, ExecutionVariables.PIPELINE_EXECUTION_ID
75-
)
76-
)
77-
steps: Sequence[Union[Step, StepCollection]] = attr.ib(factory=list)
78-
sagemaker_session: Session = attr.ib(factory=Session)
63+
Args:
64+
name (str): The name of the pipeline.
65+
parameters (Sequence[Parameter]): The list of the parameters.
66+
pipeline_experiment_config (Optional[PipelineExperimentConfig]): If set,
67+
the workflow will attempt to create an experiment and trial before
68+
executing the steps. Creation will be skipped if an experiment or a trial with
69+
the same name already exists. By default, pipeline name is used as
70+
experiment name and execution id is used as the trial name.
71+
If set to None, no experiment or trial will be created automatically.
72+
steps (Sequence[Union[Step, StepCollection]]): The list of the non-conditional steps
73+
associated with the pipeline. Any steps that are within the
74+
`if_steps` or `else_steps` of a `ConditionStep` cannot be listed in the steps of a
75+
pipeline. Of particular note, the workflow service rejects any pipeline definitions
76+
that specify a step in the list of steps of a pipeline and that step in the
77+
`if_steps` or `else_steps` of any `ConditionStep`.
78+
sagemaker_session (sagemaker.session.Session): Session object that manages interactions
79+
with Amazon SageMaker APIs and any other AWS services needed. If not specified, the
80+
pipeline creates one using the default AWS configuration chain.
81+
"""
82+
self.name = name
83+
self.parameters = parameters if parameters else []
84+
self.pipeline_experiment_config = pipeline_experiment_config
85+
self.steps = steps if steps else []
86+
self.sagemaker_session = sagemaker_session
7987

80-
_version: str = "2020-12-01"
81-
_metadata: Dict[str, Any] = dict()
88+
self._version = "2020-12-01"
89+
self._metadata = dict()
90+
self._step_map = dict()
91+
_generate_step_map(self.steps, self._step_map)
8292

8393
def to_request(self) -> RequestType:
8494
"""Gets the request structure for workflow service calls."""
@@ -305,23 +315,27 @@ def definition(self) -> str:
305315

306316
return json.dumps(request_dict)
307317

308-
def _interpolate_step_collection_name_in_depends_on(self, step_requests: dict):
318+
def _interpolate_step_collection_name_in_depends_on(self, step_requests: list):
309319
"""Insert step names as per `StepCollection` name in depends_on list
310320
311321
Args:
312-
step_requests (dict): The raw step request dict without any interpolation.
322+
step_requests (list): The list of raw step request dicts without any interpolation.
313323
"""
314-
step_name_map = {s.name: s for s in self.steps}
315324
for step_request in step_requests:
316-
if not step_request.get("DependsOn", None):
317-
continue
318325
depends_on = []
319-
for depend_step_name in step_request["DependsOn"]:
320-
if isinstance(step_name_map[depend_step_name], StepCollection):
321-
depends_on.extend([s.name for s in step_name_map[depend_step_name].steps])
326+
for depend_step_name in step_request.get("DependsOn", []):
327+
if isinstance(self._step_map[depend_step_name], StepCollection):
328+
depends_on.extend([s.name for s in self._step_map[depend_step_name].steps])
322329
else:
323330
depends_on.append(depend_step_name)
324-
step_request["DependsOn"] = depends_on
331+
if depends_on:
332+
step_request["DependsOn"] = depends_on
333+
334+
if step_request["Type"] == StepTypeEnum.CONDITION.value:
335+
sub_step_requests = (
336+
step_request["Arguments"]["IfSteps"] + step_request["Arguments"]["ElseSteps"]
337+
)
338+
self._interpolate_step_collection_name_in_depends_on(sub_step_requests)
325339

326340

327341
def format_start_parameters(parameters: Dict[str, Any]) -> List[Dict[str, Any]]:
@@ -448,6 +462,20 @@ def update_args(args: Dict[str, Any], **kwargs):
448462
args.update({key: value})
449463

450464

465+
def _generate_step_map(
466+
steps: Sequence[Union[Step, StepCollection]], step_map: dict
467+
) -> Dict[str, Any]:
468+
"""Helper method to create a mapping from Step/Step Collection name to itself."""
469+
for step in steps:
470+
if step.name in step_map:
471+
raise ValueError("Pipeline steps cannot have duplicate names.")
472+
step_map[step.name] = step
473+
if isinstance(step, ConditionStep):
474+
_generate_step_map(step.if_steps + step.else_steps, step_map)
475+
if isinstance(step, StepCollection):
476+
_generate_step_map(step.steps, step_map)
477+
478+
451479
@attr.s
452480
class _PipelineExecution:
453481
"""Internal class for encapsulating pipeline execution instances.
@@ -547,22 +575,11 @@ class PipelineGraph:
547575

548576
def __init__(self, steps: Sequence[Union[Step, StepCollection]]):
549577
self.step_map = {}
550-
self._generate_step_map(steps)
578+
_generate_step_map(steps, self.step_map)
551579
self.adjacency_list = self._initialize_adjacency_list()
552580
if self.is_cyclic():
553581
raise ValueError("Cycle detected in pipeline step graph.")
554582

555-
def _generate_step_map(self, steps: Sequence[Union[Step, StepCollection]]):
556-
"""Helper method to create a mapping from Step/Step Collection name to itself."""
557-
for step in steps:
558-
if step.name in self.step_map:
559-
raise ValueError("Pipeline steps cannot have duplicate names.")
560-
self.step_map[step.name] = step
561-
if isinstance(step, ConditionStep):
562-
self._generate_step_map(step.if_steps + step.else_steps)
563-
if isinstance(step, StepCollection):
564-
self._generate_step_map(step.steps)
565-
566583
@classmethod
567584
def from_pipeline(cls, pipeline: Pipeline):
568585
"""Create a PipelineGraph object from the Pipeline object."""

tests/unit/sagemaker/workflow/test_pipeline.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ def test_large_pipeline_create(sagemaker_session_mock, role_arn):
7878
pipeline = Pipeline(
7979
name="MyPipeline",
8080
parameters=[parameter],
81-
steps=[CustomStep(name="MyStep", input_data=parameter)] * 2000,
81+
steps=_generate_large_pipeline_steps(parameter),
8282
sagemaker_session=sagemaker_session_mock,
8383
)
8484

@@ -132,7 +132,7 @@ def test_large_pipeline_update(sagemaker_session_mock, role_arn):
132132
pipeline = Pipeline(
133133
name="MyPipeline",
134134
parameters=[parameter],
135-
steps=[CustomStep(name="MyStep", input_data=parameter)] * 2000,
135+
steps=_generate_large_pipeline_steps(parameter),
136136
sagemaker_session=sagemaker_session_mock,
137137
)
138138

@@ -437,3 +437,10 @@ def test_pipeline_execution_basics(sagemaker_session_mock):
437437
PipelineExecutionArn="my:arn"
438438
)
439439
assert len(steps) == 1
440+
441+
442+
def _generate_large_pipeline_steps(input_data: object):
443+
steps = []
444+
for i in range(2000):
445+
steps.append(CustomStep(name=f"MyStep{i}", input_data=input_data))
446+
return steps

tests/unit/sagemaker/workflow/test_pipeline_graph.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -45,10 +45,10 @@ def role_arn():
4545
def test_pipeline_duplicate_step_name(sagemaker_session_mock):
4646
step1 = CustomStep(name="foo")
4747
step2 = CustomStep(name="foo")
48-
pipeline = Pipeline(
49-
name="MyPipeline", steps=[step1, step2], sagemaker_session=sagemaker_session_mock
50-
)
5148
with pytest.raises(ValueError) as error:
49+
pipeline = Pipeline(
50+
name="MyPipeline", steps=[step1, step2], sagemaker_session=sagemaker_session_mock
51+
)
5252
PipelineGraph.from_pipeline(pipeline)
5353
assert "Pipeline steps cannot have duplicate names." in str(error.value)
5454

@@ -61,25 +61,25 @@ def test_pipeline_duplicate_step_name_in_condition_step(sagemaker_session_mock):
6161
condition_step = ConditionStep(
6262
name="condStep", conditions=[cond], depends_on=[custom_step], if_steps=[custom_step2]
6363
)
64-
pipeline = Pipeline(
65-
name="MyPipeline",
66-
steps=[custom_step, condition_step],
67-
sagemaker_session=sagemaker_session_mock,
68-
)
6964
with pytest.raises(ValueError) as error:
65+
pipeline = Pipeline(
66+
name="MyPipeline",
67+
steps=[custom_step, condition_step],
68+
sagemaker_session=sagemaker_session_mock,
69+
)
7070
PipelineGraph.from_pipeline(pipeline)
7171
assert "Pipeline steps cannot have duplicate names." in str(error.value)
7272

7373

7474
def test_pipeline_duplicate_step_name_in_step_collection(sagemaker_session_mock):
7575
custom_step = CustomStep(name="foo-1")
7676
custom_step_collection = CustomStepCollection(name="foo", depends_on=[custom_step])
77-
pipeline = Pipeline(
78-
name="MyPipeline",
79-
steps=[custom_step, custom_step_collection],
80-
sagemaker_session=sagemaker_session_mock,
81-
)
8277
with pytest.raises(ValueError) as error:
78+
pipeline = Pipeline(
79+
name="MyPipeline",
80+
steps=[custom_step, custom_step_collection],
81+
sagemaker_session=sagemaker_session_mock,
82+
)
8383
PipelineGraph.from_pipeline(pipeline)
8484
assert "Pipeline steps cannot have duplicate names." in str(error.value)
8585

0 commit comments

Comments
 (0)