-
Notifications
You must be signed in to change notification settings - Fork 1.2k
feature: add helper method to generate pipeline adjacency list #3128
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -39,6 +39,7 @@ | |
from sagemaker.workflow.properties import Properties | ||
from sagemaker.workflow.steps import Step | ||
from sagemaker.workflow.step_collections import StepCollection | ||
from sagemaker.workflow.condition_step import ConditionStep | ||
from sagemaker.workflow.utilities import list_to_request | ||
|
||
|
||
|
@@ -534,3 +535,113 @@ def wait(self, delay=30, max_attempts=60): | |
waiter_id, model, self.sagemaker_session.sagemaker_client | ||
) | ||
waiter.wait(PipelineExecutionArn=self.arn) | ||
|
||
|
||
class PipelineGraph:
    """Helper class representing the Pipeline Directed Acyclic Graph (DAG).

    Iterating over an instance yields its nodes in topological order, i.e. every
    step is yielded before any step that depends on it.

    Args:
        steps (Sequence[Union[Step, StepCollection]]): Sequence of `Step`s and/or
            `StepCollection`s that represent each node in the pipeline DAG.

    Attributes:
        step_map (Dict[str, Union[Step, StepCollection]]): Mapping from name to the
            `Step`/`StepCollection` carrying that name, including nested ones.
        adjacency_list (Dict[str, List[str]]): Mapping from a step name to the names
            of the steps that depend on it.
        stack (List[str]): Names still to be yielded by the current iteration.

    Raises:
        ValueError: If two steps share a name or the dependency graph has a cycle.
    """

    def __init__(self, steps: Sequence[Union["Step", "StepCollection"]]):
        self.step_map = {}
        # Created here (and not only in ``__iter__``) so that ``next(graph)`` on a
        # graph that was never passed to ``iter()`` raises StopIteration rather
        # than AttributeError.
        self.stack = []
        self._generate_step_map(steps)
        self.adjacency_list = self._initialize_adjacency_list()
        if self.is_cyclic():
            raise ValueError("Cycle detected in pipeline step graph.")

    def _generate_step_map(self, steps: Sequence[Union["Step", "StepCollection"]]):
        """Helper method to create a mapping from Step/Step Collection name to itself.

        Recurses into the branches of a `ConditionStep` and into the members of a
        `StepCollection` so that nested steps are registered as well.

        Raises:
            ValueError: If a name is encountered more than once.
        """
        for step in steps:
            if step.name in self.step_map:
                raise ValueError("Pipeline steps cannot have duplicate names.")
            self.step_map[step.name] = step
            if isinstance(step, ConditionStep):
                self._generate_step_map(step.if_steps + step.else_steps)
            if isinstance(step, StepCollection):
                self._generate_step_map(step.steps)

    @classmethod
    def from_pipeline(cls, pipeline: "Pipeline"):
        """Create a PipelineGraph object from the Pipeline object."""
        return cls(pipeline.steps)

    def _initialize_adjacency_list(self) -> Dict[str, List[str]]:
        """Generate an adjacency list representing the step dependency DAG in this pipeline.

        Returns:
            Dict[str, List[str]]: For every step name, the names of the steps that
            depend on it. Steps without dependents map to an empty list.
        """
        from collections import defaultdict

        # step name -> names of the steps it depends on
        dependency_list = defaultdict(set)
        for step in self.step_map.values():
            if isinstance(step, Step):
                dependency_list[step.name].update(step._find_step_dependencies(self.step_map))

            if isinstance(step, ConditionStep):
                # Every branch child depends on the condition step itself; for a
                # step collection the edge goes to the collection's first step.
                for child_step in step.if_steps + step.else_steps:
                    if isinstance(child_step, Step):
                        dependency_list[child_step.name].add(step.name)
                    elif isinstance(child_step, StepCollection):
                        child_first_step = child_step.steps[0].name
                        dependency_list[child_first_step].add(step.name)

        # Invert "depends on" into "is depended on by". Each (dependency, step)
        # pair occurs exactly once because the dependencies of a step form a set,
        # so a plain append cannot create duplicates. Sorting the set keeps the
        # result independent of string hash randomisation.
        adjacency_list = {}
        for step_name, dependencies in dependency_list.items():
            for dependency in sorted(dependencies):
                adjacency_list.setdefault(dependency, []).append(step_name)
        # Steps nobody depends on still need an entry so traversals can index them.
        for step_name in dependency_list:
            adjacency_list.setdefault(step_name, [])
        return adjacency_list

    def is_cyclic(self) -> bool:
        """Check if this pipeline graph is cyclic.

        Returns true if it is cyclic, false otherwise.
        """
        visited_steps = set()
        # Steps on the current depth-first path; reaching one of them again
        # means a back edge, i.e. a cycle.
        recurse_steps = set()

        def is_cyclic_helper(current_step):
            visited_steps.add(current_step)
            recurse_steps.add(current_step)
            for child_step in self.adjacency_list[current_step]:
                if child_step in recurse_steps:
                    return True
                if child_step not in visited_steps and is_cyclic_helper(child_step):
                    return True
            recurse_steps.remove(current_step)
            return False

        for step in self.adjacency_list:
            if step not in visited_steps and is_cyclic_helper(step):
                return True
        return False

    def __iter__(self):
        """Perform topological sort traversal of the Pipeline Graph.

        Recomputes the order on every call and returns the graph itself as the
        iterator, so starting a new iteration resets one already in progress.
        """
        visited_steps = set()
        self.stack = []

        def topological_sort(current_step):
            visited_steps.add(current_step)
            for child_step in self.adjacency_list[current_step]:
                if child_step not in visited_steps:
                    topological_sort(child_step)
            # Post-order: a step is pushed after everything that depends on it,
            # so popping from the end yields dependencies first.
            self.stack.append(current_step)

        for step in self.adjacency_list:
            if step not in visited_steps:
                topological_sort(step)
        return self

    def __next__(self) -> "Step":
        """Return the next Step node from the Topological sort order.

        Raises:
            StopIteration: When every step has been returned.
        """
        if self.stack:
            return self.step_map.get(self.stack.pop())
        raise StopIteration
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Recommendation generated by Amazon CodeGuru Reviewer. Leave feedback on this recommendation by replying to the comment or by reacting to the comment using emoji.
Modifying `object.__dict__` directly, or writing directly to the `__dict__` attribute of a class instance, is not recommended. Inside every module is a `__dict__` attribute (`object.__dict__`) which contains its symbol table. If you modify `object.__dict__`, then the symbol table is changed. Also, direct assignment to the `__dict__` attribute is not possible. Learn more