Skip to content

Commit d573e67

Browse files
nmadanNamrata Madan
andauthored
feature: Add helper method to generate pipeline adjacency list (#3128)
Co-authored-by: Namrata Madan <[email protected]>
1 parent f06dab3 commit d573e67

35 files changed

+925
-147
lines changed

src/sagemaker/workflow/_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -369,7 +369,7 @@ def __init__(
369369
self.kwargs = kwargs
370370
self.container_def_list = container_def_list
371371

372-
self._properties = Properties(path=f"Steps.{name}", shape_name="DescribeModelPackageOutput")
372+
self._properties = Properties(step_name=name, shape_name="DescribeModelPackageOutput")
373373

374374
@property
375375
def arguments(self) -> RequestType:

src/sagemaker/workflow/callback_step.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -112,13 +112,12 @@ def __init__(
112112
self.cache_config = cache_config
113113
self.inputs = inputs
114114

115-
root_path = f"Steps.{name}"
116-
root_prop = Properties(path=root_path)
115+
root_prop = Properties(step_name=name)
117116

118117
property_dict = {}
119118
for output in outputs:
120119
property_dict[output.output_name] = Properties(
121-
f"{root_path}.OutputParameters['{output.output_name}']"
120+
step_name=name, path=f"OutputParameters['{output.output_name}']"
122121
)
123122

124123
root_prop.__dict__["Outputs"] = property_dict

src/sagemaker/workflow/clarify_check_step.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -236,13 +236,12 @@ def __init__(
236236
self._generate_processing_job_analysis_config(), self._baselining_processor
237237
)
238238

239-
root_path = f"Steps.{name}"
240-
root_prop = Properties(path=root_path)
239+
root_prop = Properties(step_name=name)
241240
root_prop.__dict__["CalculatedBaselineConstraints"] = Properties(
242-
f"{root_path}.CalculatedBaselineConstraints"
241+
step_name=name, path="CalculatedBaselineConstraints"
243242
)
244243
root_prop.__dict__["BaselineUsedForDriftCheckConstraints"] = Properties(
245-
f"{root_path}.BaselineUsedForDriftCheckConstraints"
244+
step_name=name, path="BaselineUsedForDriftCheckConstraints"
246245
)
247246
self._properties = root_prop
248247

src/sagemaker/workflow/condition_step.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -77,9 +77,8 @@ def __init__(
7777
self.if_steps = if_steps or []
7878
self.else_steps = else_steps or []
7979

80-
root_path = f"Steps.{name}"
81-
root_prop = Properties(path=root_path)
82-
root_prop.__dict__["Outcome"] = Properties(f"{root_path}.Outcome")
80+
root_prop = Properties(step_name=name)
81+
root_prop.__dict__["Outcome"] = Properties(step_name=name, path="Outcome")
8382
self._properties = root_prop
8483

8584
@property
@@ -91,6 +90,11 @@ def arguments(self) -> RequestType:
9190
ElseSteps=list_to_request(self.else_steps),
9291
)
9392

93+
@property
94+
def step_only_arguments(self):
95+
"""Argument dict pertaining to the step only, and not the `if_steps` or `else_steps`."""
96+
return self.conditions
97+
9498
@property
9599
def properties(self):
96100
"""A simple Properties object with `Outcome` as the only property"""
@@ -126,5 +130,10 @@ def expr(self):
126130
}
127131
}
128132

133+
@property
134+
def _referenced_steps(self) -> List[str]:
135+
"""List of step names that this function depends on."""
136+
return [self.step.name]
137+
129138

130139
JsonGet = deprecated_class(JsonGet, "JsonGet")

src/sagemaker/workflow/conditions.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
"""
1818
from __future__ import absolute_import
1919

20+
import abc
21+
2022
from enum import Enum
2123
from typing import Dict, List, Union
2224

@@ -33,6 +35,7 @@
3335
from sagemaker.workflow.execution_variables import ExecutionVariable
3436
from sagemaker.workflow.parameters import Parameter
3537
from sagemaker.workflow.properties import Properties
38+
from sagemaker.workflow.entities import PipelineVariable
3639

3740
# TODO: consider base class for those with an expr method, rather than defining a type here
3841
ConditionValueType = Union[ExecutionVariable, Parameter, Properties]
@@ -61,6 +64,11 @@ class Condition(Entity):
6164

6265
condition_type: ConditionTypeEnum = attr.ib(factory=ConditionTypeEnum.factory)
6366

67+
@property
68+
@abc.abstractmethod
69+
def _referenced_steps(self) -> List[str]:
70+
"""List of step names that this function depends on."""
71+
6472

6573
@attr.s
6674
class ConditionComparison(Condition):
@@ -84,6 +92,16 @@ def to_request(self) -> RequestType:
8492
"RightValue": primitive_or_expr(self.right),
8593
}
8694

95+
@property
96+
def _referenced_steps(self) -> List[str]:
97+
"""List of step names that this function depends on."""
98+
steps = []
99+
if isinstance(self.left, PipelineVariable):
100+
steps.extend(self.left._referenced_steps)
101+
if isinstance(self.right, PipelineVariable):
102+
steps.extend(self.right._referenced_steps)
103+
return steps
104+
87105

88106
class ConditionEquals(ConditionComparison):
89107
"""A condition for equality comparisons."""
@@ -213,6 +231,17 @@ def to_request(self) -> RequestType:
213231
"Values": [primitive_or_expr(in_value) for in_value in self.in_values],
214232
}
215233

234+
@property
235+
def _referenced_steps(self) -> List[str]:
236+
"""List of step names that this function depends on."""
237+
steps = []
238+
if isinstance(self.value, PipelineVariable):
239+
steps.extend(self.value._referenced_steps)
240+
for in_value in self.in_values:
241+
if isinstance(in_value, PipelineVariable):
242+
steps.extend(in_value._referenced_steps)
243+
return steps
244+
216245

217246
class ConditionNot(Condition):
218247
"""A condition for negating another `Condition`."""
@@ -230,6 +259,11 @@ def to_request(self) -> RequestType:
230259
"""Get the request structure for workflow service calls."""
231260
return {"Type": self.condition_type.value, "Expression": self.expression.to_request()}
232261

262+
@property
263+
def _referenced_steps(self) -> List[str]:
264+
"""List of step names that this function depends on."""
265+
return self.expression._referenced_steps
266+
233267

234268
class ConditionOr(Condition):
235269
"""A condition for taking the logical OR of a list of `Condition` instances."""
@@ -250,6 +284,14 @@ def to_request(self) -> RequestType:
250284
"Conditions": [condition.to_request() for condition in self.conditions],
251285
}
252286

287+
@property
288+
def _referenced_steps(self) -> List[str]:
289+
"""List of step names that this function depends on."""
290+
steps = []
291+
for condition in self.conditions:
292+
steps.extend(condition._referenced_steps)
293+
return steps
294+
253295

254296
def primitive_or_expr(
255297
value: Union[ExecutionVariable, Expression, PrimitiveType, Parameter, Properties]

src/sagemaker/workflow/emr_step.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ def __init__(
9494
self.args = emr_step_args
9595
self.cache_config = cache_config
9696

97-
root_property = Properties(path=f"Steps.{name}", shape_name="Step", service_name="emr")
97+
root_property = Properties(step_name=name, shape_name="Step", service_name="emr")
9898
root_property.__dict__["ClusterId"] = cluster_id
9999
self._properties = root_property
100100

src/sagemaker/workflow/entities.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,3 +102,8 @@ def to_string(self):
102102
@abc.abstractmethod
103103
def expr(self) -> RequestType:
104104
"""Get the expression structure for workflow service calls."""
105+
106+
@property
107+
@abc.abstractmethod
108+
def _referenced_steps(self) -> List[str]:
109+
"""List of step names that this function depends on."""

src/sagemaker/workflow/execution_variables.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
"""Pipeline parameters and conditions for workflow."""
1414
from __future__ import absolute_import
1515

16+
from typing import List
1617
from sagemaker.workflow.entities import (
1718
RequestType,
1819
PipelineVariable,
@@ -42,6 +43,11 @@ def expr(self) -> RequestType:
4243
"""The 'Get' expression dict for an `ExecutionVariable`."""
4344
return {"Get": f"Execution.{self.name}"}
4445

46+
@property
47+
def _referenced_steps(self) -> List[str]:
48+
"""List of step names that this function depends on."""
49+
return []
50+
4551

4652
class ExecutionVariables:
4753
"""Provide access to all available execution variables:

src/sagemaker/workflow/functions.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,15 @@ def expr(self):
6464
},
6565
}
6666

67+
@property
68+
def _referenced_steps(self) -> List[str]:
69+
"""List of step names that this function depends on."""
70+
steps = []
71+
for value in self.values:
72+
if isinstance(value, PipelineVariable):
73+
steps.extend(value._referenced_steps)
74+
return steps
75+
6776

6877
@attr.s
6978
class JsonGet(PipelineVariable):
@@ -96,3 +105,8 @@ def expr(self):
96105
"Path": self.json_path,
97106
}
98107
}
108+
109+
@property
110+
def _referenced_steps(self) -> List[str]:
111+
"""List of step names that this function depends on."""
112+
return [self.step_name]

src/sagemaker/workflow/lambda_step.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -115,13 +115,12 @@ def __init__(
115115
self.cache_config = cache_config
116116
self.inputs = inputs if inputs is not None else {}
117117

118-
root_path = f"Steps.{name}"
119-
root_prop = Properties(path=root_path)
118+
root_prop = Properties(step_name=name)
120119

121120
property_dict = {}
122121
for output in self.outputs:
123122
property_dict[output.output_name] = Properties(
124-
f"{root_path}.OutputParameters['{output.output_name}']"
123+
step_name=name, path=f"OutputParameters['{output.output_name}']"
125124
)
126125

127126
root_prop.__dict__["Outputs"] = property_dict

src/sagemaker/workflow/parameters.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,11 @@ def expr(self) -> Dict[str, str]:
9090
"""The 'Get' expression dict for a `Parameter`."""
9191
return Parameter._expr(self.name)
9292

93+
@property
94+
def _referenced_steps(self) -> List[str]:
95+
"""List of step names that this function depends on."""
96+
return []
97+
9398
@classmethod
9499
def _expr(cls, name):
95100
"""An internal classmethod for the 'Get' expression dict for a `Parameter`.

src/sagemaker/workflow/pipeline.py

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
from sagemaker.workflow.properties import Properties
4040
from sagemaker.workflow.steps import Step
4141
from sagemaker.workflow.step_collections import StepCollection
42+
from sagemaker.workflow.condition_step import ConditionStep
4243
from sagemaker.workflow.utilities import list_to_request
4344

4445

@@ -534,3 +535,113 @@ def wait(self, delay=30, max_attempts=60):
534535
waiter_id, model, self.sagemaker_session.sagemaker_client
535536
)
536537
waiter.wait(PipelineExecutionArn=self.arn)
538+
539+
540+
class PipelineGraph:
541+
"""Helper class representing the Pipeline Directed Acyclic Graph (DAG)
542+
543+
Attributes:
544+
steps (Sequence[Union[Step, StepCollection]]): Sequence of `Step`s and/or `StepCollection`s
545+
that represent each node in the pipeline DAG
546+
"""
547+
548+
def __init__(self, steps: Sequence[Union[Step, StepCollection]]):
549+
self.step_map = {}
550+
self._generate_step_map(steps)
551+
self.adjacency_list = self._initialize_adjacency_list()
552+
if self.is_cyclic():
553+
raise ValueError("Cycle detected in pipeline step graph.")
554+
555+
def _generate_step_map(self, steps: Sequence[Union[Step, StepCollection]]):
556+
"""Helper method to create a mapping from Step/Step Collection name to itself."""
557+
for step in steps:
558+
if step.name in self.step_map:
559+
raise ValueError("Pipeline steps cannot have duplicate names.")
560+
self.step_map[step.name] = step
561+
if isinstance(step, ConditionStep):
562+
self._generate_step_map(step.if_steps + step.else_steps)
563+
if isinstance(step, StepCollection):
564+
self._generate_step_map(step.steps)
565+
566+
@classmethod
567+
def from_pipeline(cls, pipeline: Pipeline):
568+
"""Create a PipelineGraph object from the Pipeline object."""
569+
return cls(pipeline.steps)
570+
571+
def _initialize_adjacency_list(self) -> Dict[str, List[str]]:
572+
"""Generate an adjacency list representing the step dependency DAG in this pipeline."""
573+
from collections import defaultdict
574+
575+
dependency_list = defaultdict(set)
576+
for step in self.step_map.values():
577+
if isinstance(step, Step):
578+
dependency_list[step.name].update(step._find_step_dependencies(self.step_map))
579+
580+
if isinstance(step, ConditionStep):
581+
for child_step in step.if_steps + step.else_steps:
582+
if isinstance(child_step, Step):
583+
dependency_list[child_step.name].add(step.name)
584+
elif isinstance(child_step, StepCollection):
585+
child_first_step = self.step_map[child_step.name].steps[0].name
586+
dependency_list[child_first_step].add(step.name)
587+
588+
adjacency_list = {}
589+
for step in dependency_list:
590+
for step_dependency in dependency_list[step]:
591+
adjacency_list[step_dependency] = list(
592+
set(adjacency_list.get(step_dependency, []) + [step])
593+
)
594+
for step in dependency_list:
595+
if step not in adjacency_list:
596+
adjacency_list[step] = []
597+
return adjacency_list
598+
599+
def is_cyclic(self) -> bool:
600+
"""Check if this pipeline graph is cyclic.
601+
602+
Returns true if it is cyclic, false otherwise.
603+
"""
604+
605+
def is_cyclic_helper(current_step):
606+
visited_steps.add(current_step)
607+
recurse_steps.add(current_step)
608+
for child_step in self.adjacency_list[current_step]:
609+
if child_step in recurse_steps:
610+
return True
611+
if child_step not in visited_steps:
612+
if is_cyclic_helper(child_step):
613+
return True
614+
recurse_steps.remove(current_step)
615+
return False
616+
617+
visited_steps = set()
618+
recurse_steps = set()
619+
for step in self.adjacency_list:
620+
if step not in visited_steps:
621+
if is_cyclic_helper(step):
622+
return True
623+
return False
624+
625+
def __iter__(self):
626+
"""Perform topological sort traversal of the Pipeline Graph."""
627+
628+
def topological_sort(current_step):
629+
visited_steps.add(current_step)
630+
for child_step in self.adjacency_list[current_step]:
631+
if child_step not in visited_steps:
632+
topological_sort(child_step)
633+
self.stack.append(current_step)
634+
635+
visited_steps = set()
636+
self.stack = [] # pylint: disable=W0201
637+
for step in self.adjacency_list:
638+
if step not in visited_steps:
639+
topological_sort(step)
640+
return self
641+
642+
def __next__(self) -> Step:
643+
"""Return the next Step node from the Topological sort order."""
644+
645+
while self.stack:
646+
return self.step_map.get(self.stack.pop())
647+
raise StopIteration

0 commit comments

Comments
 (0)