Skip to content

feature: add helper method to generate pipeline adjacency list #3128

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 17, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/sagemaker/workflow/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,7 +369,7 @@ def __init__(
self.kwargs = kwargs
self.container_def_list = container_def_list

self._properties = Properties(path=f"Steps.{name}", shape_name="DescribeModelPackageOutput")
self._properties = Properties(step_name=name, shape_name="DescribeModelPackageOutput")

@property
def arguments(self) -> RequestType:
Expand Down
5 changes: 2 additions & 3 deletions src/sagemaker/workflow/callback_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,13 +112,12 @@ def __init__(
self.cache_config = cache_config
self.inputs = inputs

root_path = f"Steps.{name}"
root_prop = Properties(path=root_path)
root_prop = Properties(step_name=name)

property_dict = {}
for output in outputs:
property_dict[output.output_name] = Properties(
f"{root_path}.OutputParameters['{output.output_name}']"
step_name=name, path=f"OutputParameters['{output.output_name}']"
)

root_prop.__dict__["Outputs"] = property_dict
Expand Down
7 changes: 3 additions & 4 deletions src/sagemaker/workflow/clarify_check_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,13 +236,12 @@ def __init__(
self._generate_processing_job_analysis_config(), self._baselining_processor
)

root_path = f"Steps.{name}"
root_prop = Properties(path=root_path)
root_prop = Properties(step_name=name)
root_prop.__dict__["CalculatedBaselineConstraints"] = Properties(
f"{root_path}.CalculatedBaselineConstraints"
step_name=name, path="CalculatedBaselineConstraints"
)
root_prop.__dict__["BaselineUsedForDriftCheckConstraints"] = Properties(
f"{root_path}.BaselineUsedForDriftCheckConstraints"
step_name=name, path="BaselineUsedForDriftCheckConstraints"
)
self._properties = root_prop

Expand Down
15 changes: 12 additions & 3 deletions src/sagemaker/workflow/condition_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,8 @@ def __init__(
self.if_steps = if_steps or []
self.else_steps = else_steps or []

root_path = f"Steps.{name}"
root_prop = Properties(path=root_path)
root_prop.__dict__["Outcome"] = Properties(f"{root_path}.Outcome")
root_prop = Properties(step_name=name)
root_prop.__dict__["Outcome"] = Properties(step_name=name, path="Outcome")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Recommendation generated by Amazon CodeGuru Reviewer. Leave feedback on this recommendation by replying to the comment or by reacting to the comment using emoji.

Modifying object.__dict__ directly or writing to an instance of a class __dict__ attribute directly is not recommended. Inside every module is a __dict__ object.dict attribute which contains its symbol table. If you modify object.__dict__, then the symbol table is changed. Also, direct assignment to the __dict__ attribute is not possible.

Learn more

self._properties = root_prop

@property
Expand All @@ -91,6 +90,11 @@ def arguments(self) -> RequestType:
ElseSteps=list_to_request(self.else_steps),
)

@property
def step_only_arguments(self):
"""Argument dict pertaining to the step only, and not the `if_steps` or `else_steps`."""
return self.conditions

@property
def properties(self):
"""A simple Properties object with `Outcome` as the only property"""
Expand Down Expand Up @@ -126,5 +130,10 @@ def expr(self):
}
}

@property
def _referenced_steps(self) -> List[str]:
"""List of step names that this function depends on."""
return [self.step.name]


JsonGet = deprecated_class(JsonGet, "JsonGet")
42 changes: 42 additions & 0 deletions src/sagemaker/workflow/conditions.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
"""
from __future__ import absolute_import

import abc

from enum import Enum
from typing import Dict, List, Union

Expand All @@ -33,6 +35,7 @@
from sagemaker.workflow.execution_variables import ExecutionVariable
from sagemaker.workflow.parameters import Parameter
from sagemaker.workflow.properties import Properties
from sagemaker.workflow.entities import PipelineVariable

# TODO: consider base class for those with an expr method, rather than defining a type here
ConditionValueType = Union[ExecutionVariable, Parameter, Properties]
Expand Down Expand Up @@ -61,6 +64,11 @@ class Condition(Entity):

condition_type: ConditionTypeEnum = attr.ib(factory=ConditionTypeEnum.factory)

@property
@abc.abstractmethod
def _referenced_steps(self) -> List[str]:
"""List of step names that this function depends on."""


@attr.s
class ConditionComparison(Condition):
Expand All @@ -84,6 +92,16 @@ def to_request(self) -> RequestType:
"RightValue": primitive_or_expr(self.right),
}

@property
def _referenced_steps(self) -> List[str]:
"""List of step names that this function depends on."""
steps = []
if isinstance(self.left, PipelineVariable):
steps.extend(self.left._referenced_steps)
if isinstance(self.right, PipelineVariable):
steps.extend(self.right._referenced_steps)
return steps


class ConditionEquals(ConditionComparison):
"""A condition for equality comparisons."""
Expand Down Expand Up @@ -213,6 +231,17 @@ def to_request(self) -> RequestType:
"Values": [primitive_or_expr(in_value) for in_value in self.in_values],
}

@property
def _referenced_steps(self) -> List[str]:
"""List of step names that this function depends on."""
steps = []
if isinstance(self.value, PipelineVariable):
steps.extend(self.value._referenced_steps)
for in_value in self.in_values:
if isinstance(in_value, PipelineVariable):
steps.extend(in_value._referenced_steps)
return steps


class ConditionNot(Condition):
"""A condition for negating another `Condition`."""
Expand All @@ -230,6 +259,11 @@ def to_request(self) -> RequestType:
"""Get the request structure for workflow service calls."""
return {"Type": self.condition_type.value, "Expression": self.expression.to_request()}

@property
def _referenced_steps(self) -> List[str]:
"""List of step names that this function depends on."""
return self.expression._referenced_steps


class ConditionOr(Condition):
"""A condition for taking the logical OR of a list of `Condition` instances."""
Expand All @@ -250,6 +284,14 @@ def to_request(self) -> RequestType:
"Conditions": [condition.to_request() for condition in self.conditions],
}

@property
def _referenced_steps(self) -> List[str]:
"""List of step names that this function depends on."""
steps = []
for condition in self.conditions:
steps.extend(condition._referenced_steps)
return steps


def primitive_or_expr(
value: Union[ExecutionVariable, Expression, PrimitiveType, Parameter, Properties]
Expand Down
2 changes: 1 addition & 1 deletion src/sagemaker/workflow/emr_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def __init__(
self.args = emr_step_args
self.cache_config = cache_config

root_property = Properties(path=f"Steps.{name}", shape_name="Step", service_name="emr")
root_property = Properties(step_name=name, shape_name="Step", service_name="emr")
root_property.__dict__["ClusterId"] = cluster_id
self._properties = root_property

Expand Down
5 changes: 5 additions & 0 deletions src/sagemaker/workflow/entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,3 +102,8 @@ def to_string(self):
@abc.abstractmethod
def expr(self) -> RequestType:
"""Get the expression structure for workflow service calls."""

@property
@abc.abstractmethod
def _referenced_steps(self) -> List[str]:
"""List of step names that this function depends on."""
6 changes: 6 additions & 0 deletions src/sagemaker/workflow/execution_variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
"""Pipeline parameters and conditions for workflow."""
from __future__ import absolute_import

from typing import List
from sagemaker.workflow.entities import (
RequestType,
PipelineVariable,
Expand Down Expand Up @@ -42,6 +43,11 @@ def expr(self) -> RequestType:
"""The 'Get' expression dict for an `ExecutionVariable`."""
return {"Get": f"Execution.{self.name}"}

@property
def _referenced_steps(self) -> List[str]:
"""List of step names that this function depends on."""
return []


class ExecutionVariables:
"""Provide access to all available execution variables:
Expand Down
14 changes: 14 additions & 0 deletions src/sagemaker/workflow/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,15 @@ def expr(self):
},
}

@property
def _referenced_steps(self) -> List[str]:
"""List of step names that this function depends on."""
steps = []
for value in self.values:
if isinstance(value, PipelineVariable):
steps.extend(value._referenced_steps)
return steps


@attr.s
class JsonGet(PipelineVariable):
Expand Down Expand Up @@ -96,3 +105,8 @@ def expr(self):
"Path": self.json_path,
}
}

@property
def _referenced_steps(self) -> List[str]:
"""List of step names that this function depends on."""
return [self.step_name]
5 changes: 2 additions & 3 deletions src/sagemaker/workflow/lambda_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,13 +115,12 @@ def __init__(
self.cache_config = cache_config
self.inputs = inputs if inputs is not None else {}

root_path = f"Steps.{name}"
root_prop = Properties(path=root_path)
root_prop = Properties(step_name=name)

property_dict = {}
for output in self.outputs:
property_dict[output.output_name] = Properties(
f"{root_path}.OutputParameters['{output.output_name}']"
step_name=name, path=f"OutputParameters['{output.output_name}']"
)

root_prop.__dict__["Outputs"] = property_dict
Expand Down
5 changes: 5 additions & 0 deletions src/sagemaker/workflow/parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,11 @@ def expr(self) -> Dict[str, str]:
"""The 'Get' expression dict for a `Parameter`."""
return Parameter._expr(self.name)

@property
def _referenced_steps(self) -> List[str]:
"""List of step names that this function depends on."""
return []

@classmethod
def _expr(cls, name):
"""An internal classmethod for the 'Get' expression dict for a `Parameter`.
Expand Down
111 changes: 111 additions & 0 deletions src/sagemaker/workflow/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
from sagemaker.workflow.properties import Properties
from sagemaker.workflow.steps import Step
from sagemaker.workflow.step_collections import StepCollection
from sagemaker.workflow.condition_step import ConditionStep
from sagemaker.workflow.utilities import list_to_request


Expand Down Expand Up @@ -534,3 +535,113 @@ def wait(self, delay=30, max_attempts=60):
waiter_id, model, self.sagemaker_session.sagemaker_client
)
waiter.wait(PipelineExecutionArn=self.arn)


class PipelineGraph:
"""Helper class representing the Pipeline Directed Acyclic Graph (DAG)

Attributes:
steps (Sequence[Union[Step, StepCollection]]): Sequence of `Step`s and/or `StepCollection`s
that represent each node in the pipeline DAG
"""

def __init__(self, steps: Sequence[Union[Step, StepCollection]]):
self.step_map = {}
self._generate_step_map(steps)
self.adjacency_list = self._initialize_adjacency_list()
if self.is_cyclic():
raise ValueError("Cycle detected in pipeline step graph.")

def _generate_step_map(self, steps: Sequence[Union[Step, StepCollection]]):
"""Helper method to create a mapping from Step/Step Collection name to itself."""
for step in steps:
if step.name in self.step_map:
raise ValueError("Pipeline steps cannot have duplicate names.")
self.step_map[step.name] = step
if isinstance(step, ConditionStep):
self._generate_step_map(step.if_steps + step.else_steps)
if isinstance(step, StepCollection):
self._generate_step_map(step.steps)

@classmethod
def from_pipeline(cls, pipeline: Pipeline):
"""Create a PipelineGraph object from the Pipeline object."""
return cls(pipeline.steps)

def _initialize_adjacency_list(self) -> Dict[str, List[str]]:
"""Generate an adjacency list representing the step dependency DAG in this pipeline."""
from collections import defaultdict

dependency_list = defaultdict(set)
for step in self.step_map.values():
if isinstance(step, Step):
dependency_list[step.name].update(step._find_step_dependencies(self.step_map))

if isinstance(step, ConditionStep):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The if else branches in the condition step are already flattened out. So it just need to be treated as a single node.

The condition step itself has condition expressions, which may contain property references.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This comment can be ignored.

for child_step in step.if_steps + step.else_steps:
if isinstance(child_step, Step):
dependency_list[child_step.name].add(step.name)
elif isinstance(child_step, StepCollection):
child_first_step = self.step_map[child_step.name].steps[0].name
dependency_list[child_first_step].add(step.name)

adjacency_list = {}
for step in dependency_list:
for step_dependency in dependency_list[step]:
adjacency_list[step_dependency] = list(
set(adjacency_list.get(step_dependency, []) + [step])
)
for step in dependency_list:
if step not in adjacency_list:
adjacency_list[step] = []
return adjacency_list

def is_cyclic(self) -> bool:
"""Check if this pipeline graph is cyclic.

Returns true if it is cyclic, false otherwise.
"""

def is_cyclic_helper(current_step):
visited_steps.add(current_step)
recurse_steps.add(current_step)
for child_step in self.adjacency_list[current_step]:
if child_step in recurse_steps:
return True
if child_step not in visited_steps:
if is_cyclic_helper(child_step):
return True
recurse_steps.remove(current_step)
return False

visited_steps = set()
recurse_steps = set()
for step in self.adjacency_list:
if step not in visited_steps:
if is_cyclic_helper(step):
return True
return False

def __iter__(self):
"""Perform topological sort traversal of the Pipeline Graph."""

def topological_sort(current_step):
visited_steps.add(current_step)
for child_step in self.adjacency_list[current_step]:
if child_step not in visited_steps:
topological_sort(child_step)
self.stack.append(current_step)

visited_steps = set()
self.stack = [] # pylint: disable=W0201
for step in self.adjacency_list:
if step not in visited_steps:
topological_sort(step)
return self

def __next__(self) -> Step:
"""Return the next Step node from the Topological sort order."""

while self.stack:
return self.step_map.get(self.stack.pop())
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Recommendation generated by Amazon CodeGuru Reviewer. Leave feedback on this recommendation by replying to the comment or by reacting to the comment using emoji.

You are using the get method without a default argument to return the value of a key in a dictionary. We recommended that you use a default argument so that if the value for your key is not found, a default value is returned. If a default value is not provided and the key is not found, then None is returned.

Learn more

raise StopIteration
Loading