Skip to content

Commit 608a09c

Browse files
author
Namrata Madan
committed
feature: add helper method to generate pipeline adjacency list
1 parent f5a9e28 commit 608a09c

34 files changed

+609
-112
lines changed

src/sagemaker/workflow/_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -369,7 +369,7 @@ def __init__(
369369
self.kwargs = kwargs
370370
self.container_def_list = container_def_list
371371

372-
self._properties = Properties(path=f"Steps.{name}", shape_name="DescribeModelPackageOutput")
372+
self._properties = Properties(step_name=name, shape_name="DescribeModelPackageOutput")
373373

374374
@property
375375
def arguments(self) -> RequestType:

src/sagemaker/workflow/callback_step.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -112,13 +112,12 @@ def __init__(
112112
self.cache_config = cache_config
113113
self.inputs = inputs
114114

115-
root_path = f"Steps.{name}"
116-
root_prop = Properties(path=root_path)
115+
root_prop = Properties(step_name=name)
117116

118117
property_dict = {}
119118
for output in outputs:
120119
property_dict[output.output_name] = Properties(
121-
f"{root_path}.OutputParameters['{output.output_name}']"
120+
step_name=name, path=f"OutputParameters['{output.output_name}']"
122121
)
123122

124123
root_prop.__dict__["Outputs"] = property_dict

src/sagemaker/workflow/clarify_check_step.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -236,13 +236,12 @@ def __init__(
236236
self._generate_processing_job_analysis_config(), self._baselining_processor
237237
)
238238

239-
root_path = f"Steps.{name}"
240-
root_prop = Properties(path=root_path)
239+
root_prop = Properties(step_name=name)
241240
root_prop.__dict__["CalculatedBaselineConstraints"] = Properties(
242-
f"{root_path}.CalculatedBaselineConstraints"
241+
step_name=name, path="CalculatedBaselineConstraints"
243242
)
244243
root_prop.__dict__["BaselineUsedForDriftCheckConstraints"] = Properties(
245-
f"{root_path}.BaselineUsedForDriftCheckConstraints"
244+
step_name=name, path="BaselineUsedForDriftCheckConstraints"
246245
)
247246
self._properties = root_prop
248247

src/sagemaker/workflow/condition_step.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -77,9 +77,8 @@ def __init__(
7777
self.if_steps = if_steps or []
7878
self.else_steps = else_steps or []
7979

80-
root_path = f"Steps.{name}"
81-
root_prop = Properties(path=root_path)
82-
root_prop.__dict__["Outcome"] = Properties(f"{root_path}.Outcome")
80+
root_prop = Properties(step_name=name)
81+
root_prop.__dict__["Outcome"] = Properties(step_name=name, path="Outcome")
8382
self._properties = root_prop
8483

8584
@property
@@ -91,6 +90,11 @@ def arguments(self) -> RequestType:
9190
ElseSteps=list_to_request(self.else_steps),
9291
)
9392

93+
@property
94+
def step_only_arguments(self):
95+
"""Argument dict pertaining to the step only, and not the `if_steps` or `else_steps`."""
96+
return self.arguments["Conditions"]
97+
9498
@property
9599
def properties(self):
96100
"""A simple Properties object with `Outcome` as the only property"""
@@ -126,5 +130,10 @@ def expr(self):
126130
}
127131
}
128132

133+
@property
134+
def depends_on(self) -> List[str]:
135+
"""List of step names that this function depends on."""
136+
return [self.step.name]
137+
129138

130139
JsonGet = deprecated_class(JsonGet, "JsonGet")

src/sagemaker/workflow/emr_step.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ def __init__(
9494
self.args = emr_step_args
9595
self.cache_config = cache_config
9696

97-
root_property = Properties(path=f"Steps.{name}", shape_name="Step", service_name="emr")
97+
root_property = Properties(step_name=name, shape_name="Step", service_name="emr")
9898
root_property.__dict__["ClusterId"] = cluster_id
9999
self._properties = root_property
100100

src/sagemaker/workflow/entities.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,3 +102,8 @@ def to_string(self):
102102
@abc.abstractmethod
103103
def expr(self) -> RequestType:
104104
"""Get the expression structure for workflow service calls."""
105+
106+
@property
107+
@abc.abstractmethod
108+
def depends_on(self) -> List[str]:
109+
"""List of step names that this function depends on."""

src/sagemaker/workflow/execution_variables.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
"""Pipeline parameters and conditions for workflow."""
1414
from __future__ import absolute_import
1515

16+
from typing import List
1617
from sagemaker.workflow.entities import (
1718
RequestType,
1819
PipelineVariable,
@@ -42,6 +43,11 @@ def expr(self) -> RequestType:
4243
"""The 'Get' expression dict for an `ExecutionVariable`."""
4344
return {"Get": f"Execution.{self.name}"}
4445

46+
@property
47+
def depends_on(self) -> List[str]:
48+
"""List of step names that this function depends on."""
49+
return []
50+
4551

4652
class ExecutionVariables:
4753
"""Provide access to all available execution variables:

src/sagemaker/workflow/functions.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,15 @@ def expr(self):
6464
},
6565
}
6666

67+
@property
68+
def depends_on(self) -> List[str]:
69+
"""List of step names that this function depends on."""
70+
steps = []
71+
for value in self.values:
72+
if hasattr(value, "depends_on"):
73+
steps.extend(value.depends_on)
74+
return steps
75+
6776

6877
@attr.s
6978
class JsonGet(PipelineVariable):
@@ -96,3 +105,8 @@ def expr(self):
96105
"Path": self.json_path,
97106
}
98107
}
108+
109+
@property
110+
def depends_on(self) -> List[str]:
111+
"""List of step names that this function depends on."""
112+
return [self.step_name]

src/sagemaker/workflow/lambda_step.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -115,13 +115,12 @@ def __init__(
115115
self.cache_config = cache_config
116116
self.inputs = inputs if inputs is not None else {}
117117

118-
root_path = f"Steps.{name}"
119-
root_prop = Properties(path=root_path)
118+
root_prop = Properties(step_name=name)
120119

121120
property_dict = {}
122121
for output in self.outputs:
123122
property_dict[output.output_name] = Properties(
124-
f"{root_path}.OutputParameters['{output.output_name}']"
123+
step_name=name, path=f"OutputParameters['{output.output_name}']"
125124
)
126125

127126
root_prop.__dict__["Outputs"] = property_dict

src/sagemaker/workflow/parameters.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,11 @@ def expr(self) -> Dict[str, str]:
9090
"""The 'Get' expression dict for a `Parameter`."""
9191
return Parameter._expr(self.name)
9292

93+
@property
94+
def depends_on(self) -> List[str]:
95+
"""List of step names that this function depends on."""
96+
return []
97+
9398
@classmethod
9499
def _expr(cls, name):
95100
"""An internal classmethod for the 'Get' expression dict for a `Parameter`.

src/sagemaker/workflow/pipeline.py

Lines changed: 137 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
import json
1717

1818
from copy import deepcopy
19-
from typing import Any, Dict, List, Sequence, Union, Optional
19+
from typing import Any, Dict, List, Sequence, Union, Optional, Set
2020

2121
import attr
2222
import botocore
@@ -39,6 +39,7 @@
3939
from sagemaker.workflow.properties import Properties
4040
from sagemaker.workflow.steps import Step
4141
from sagemaker.workflow.step_collections import StepCollection
42+
from sagemaker.workflow.condition_step import ConditionStep
4243
from sagemaker.workflow.utilities import list_to_request
4344

4445

@@ -534,3 +535,138 @@ def wait(self, delay=30, max_attempts=60):
534535
waiter_id, model, self.sagemaker_session.sagemaker_client
535536
)
536537
waiter.wait(PipelineExecutionArn=self.arn)
538+
539+
540+
class PipelineGraph:
541+
"""Helper class representing the Pipeline Directed Acyclic Graph (DAG)
542+
543+
Attributes:
544+
steps (Sequence[Union[Step, StepCollection]]): Sequence of `Step`s and/or `StepCollection`s
545+
that represent each node in the pipeline DAG
546+
"""
547+
548+
def __init__(self, steps: Sequence[Union[Step, StepCollection]]):
549+
if self.is_duplicate_step_name(steps):
550+
raise ValueError("Pipeline steps cannot have duplicate names.")
551+
self.step_map = PipelineGraph._generate_step_map(steps)
552+
self.adjacency_list = self._initialize_adjacency_list()
553+
if self.is_cyclic():
554+
raise ValueError("Cycle detected in pipeline step graph.")
555+
556+
@staticmethod
557+
def _generate_step_map(
558+
steps: Sequence[Union[Step, StepCollection]]
559+
) -> Dict[str, Union[Step, StepCollection]]:
560+
"""Helper method to create a mapping from Step/Step Collection name to itself."""
561+
step_map = {}
562+
for step in steps:
563+
if isinstance(step, Step):
564+
step_map[step.name] = step
565+
if isinstance(step, ConditionStep):
566+
step_map.update(
567+
PipelineGraph._generate_step_map(step.if_steps + step.else_steps)
568+
)
569+
elif isinstance(step, StepCollection):
570+
step_map[step.name] = step
571+
for inner_step in step.steps:
572+
step_map[inner_step.name] = inner_step
573+
return step_map
574+
575+
@classmethod
576+
def from_pipeline(cls, pipeline: Pipeline):
577+
"""Create a PipelineGraph object from the Pipeline object."""
578+
return cls(pipeline.steps)
579+
580+
def _initialize_adjacency_list(self) -> Dict[str, Set[str]]:
581+
"""Generate an adjacency list representing the step dependency DAG in this pipeline."""
582+
from collections import defaultdict
583+
584+
dependency_list = defaultdict(set)
585+
for step in self.step_map.values():
586+
if isinstance(step, Step):
587+
dependency_list[step.name].update(step._find_step_dependencies(self.step_map))
588+
elif isinstance(step, StepCollection):
589+
step_collection_dependencies = step._find_step_dependencies(self.step_map)
590+
for k, v in step_collection_dependencies.items():
591+
dependency_list[k].update(v)
592+
593+
if isinstance(step, ConditionStep):
594+
for child_step in step.if_steps + step.else_steps:
595+
if isinstance(child_step, Step):
596+
dependency_list[child_step.name].add(step.name)
597+
elif isinstance(child_step, StepCollection):
598+
child_first_step = self.step_map[child_step.name].steps[0].name
599+
dependency_list[child_first_step].add(step.name)
600+
601+
adjacency_list = {}
602+
for step in dependency_list:
603+
for step_dependency in dependency_list[step]:
604+
adjacency_list[step_dependency] = list(
605+
set(adjacency_list.get(step_dependency, []) + [step])
606+
)
607+
for step in dependency_list:
608+
if step not in adjacency_list:
609+
adjacency_list[step] = []
610+
return adjacency_list
611+
612+
def is_duplicate_step_name(self, steps: Sequence[Union[Step, StepCollection]]) -> bool:
613+
"""Check if the provided step names have any duplicates.
614+
615+
Returns true if there are one or more duplicates, false otherwise.
616+
"""
617+
step_names = set()
618+
for step in steps:
619+
if step.name in step_names:
620+
return True
621+
step_names.add(step.name)
622+
return False
623+
624+
def is_cyclic(self) -> bool:
625+
"""Check if this pipeline graph is cyclic.
626+
627+
Returns true if it is cyclic, false otherwise.
628+
"""
629+
630+
def is_cyclic_helper(current_step):
631+
visited_steps.add(current_step)
632+
recurse_steps.add(current_step)
633+
for child_step in self.adjacency_list[current_step]:
634+
if child_step in recurse_steps:
635+
return True
636+
if child_step not in visited_steps:
637+
if is_cyclic_helper(child_step):
638+
return True
639+
recurse_steps.remove(current_step)
640+
return False
641+
642+
visited_steps = set()
643+
recurse_steps = set()
644+
for step in self.adjacency_list:
645+
if step not in visited_steps:
646+
if is_cyclic_helper(step):
647+
return True
648+
return False
649+
650+
def __iter__(self):
651+
"""Perform topological sort traversal of the Pipeline Graph."""
652+
653+
def topological_sort(current_step):
654+
visited_steps.add(current_step)
655+
for child_step in self.adjacency_list[current_step]:
656+
if child_step not in visited_steps:
657+
topological_sort(child_step)
658+
self.stack.append(current_step)
659+
660+
visited_steps = set()
661+
self.stack = [] # pylint: disable=W0201
662+
for step in self.adjacency_list:
663+
if step not in visited_steps:
664+
topological_sort(step)
665+
return self
666+
667+
def __next__(self) -> Step:
668+
"""Return the next Step node from the Topological sort order."""
669+
670+
while self.stack:
671+
return self.step_map.get(self.stack.pop())
672+
raise StopIteration

0 commit comments

Comments
 (0)