Skip to content

Commit ce7f47f

Browse files
author
Namrata Madan
committed
feature: add helper method to generate pipeline adjacency list
1 parent 255a339 commit ce7f47f

25 files changed

+164
-71
lines changed

src/sagemaker/workflow/_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -365,7 +365,7 @@ def __init__(
365365
self.kwargs = kwargs
366366
self.container_def_list = container_def_list
367367

368-
self._properties = Properties(path=f"Steps.{name}", shape_name="DescribeModelPackageOutput")
368+
self._properties = Properties(step_name=name, shape_name="DescribeModelPackageOutput")
369369

370370
@property
371371
def arguments(self) -> RequestType:

src/sagemaker/workflow/callback_step.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -112,13 +112,12 @@ def __init__(
112112
self.cache_config = cache_config
113113
self.inputs = inputs
114114

115-
root_path = f"Steps.{name}"
116-
root_prop = Properties(path=root_path)
115+
root_prop = Properties(step_name=name)
117116

118117
property_dict = {}
119118
for output in outputs:
120119
property_dict[output.output_name] = Properties(
121-
f"{root_path}.OutputParameters['{output.output_name}']"
120+
step_name=name, path=f"OutputParameters['{output.output_name}']"
122121
)
123122

124123
root_prop.__dict__["Outputs"] = property_dict

src/sagemaker/workflow/clarify_check_step.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -236,13 +236,12 @@ def __init__(
236236
self._generate_processing_job_analysis_config(), self._baselining_processor
237237
)
238238

239-
root_path = f"Steps.{name}"
240-
root_prop = Properties(path=root_path)
239+
root_prop = Properties(step_name=name)
241240
root_prop.__dict__["CalculatedBaselineConstraints"] = Properties(
242-
f"{root_path}.CalculatedBaselineConstraints"
241+
step_name=name, path="CalculatedBaselineConstraints"
243242
)
244243
root_prop.__dict__["BaselineUsedForDriftCheckConstraints"] = Properties(
245-
f"{root_path}.BaselineUsedForDriftCheckConstraints"
244+
step_name=name, path="BaselineUsedForDriftCheckConstraints"
246245
)
247246
self._properties = root_prop
248247

src/sagemaker/workflow/condition_step.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -77,9 +77,8 @@ def __init__(
7777
self.if_steps = if_steps or []
7878
self.else_steps = else_steps or []
7979

80-
root_path = f"Steps.{name}"
81-
root_prop = Properties(path=root_path)
82-
root_prop.__dict__["Outcome"] = Properties(f"{root_path}.Outcome")
80+
root_prop = Properties(step_name=name)
81+
root_prop.__dict__["Outcome"] = Properties(step_name=name, path="Outcome")
8382
self._properties = root_prop
8483

8584
@property

src/sagemaker/workflow/emr_step.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ def __init__(
9494
self.args = emr_step_args
9595
self.cache_config = cache_config
9696

97-
root_property = Properties(path=f"Steps.{name}", shape_name="Step", service_name="emr")
97+
root_property = Properties(step_name=name, shape_name="Step", service_name="emr")
9898
root_property.__dict__["ClusterId"] = cluster_id
9999
self._properties = root_property
100100

src/sagemaker/workflow/lambda_step.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -115,13 +115,12 @@ def __init__(
115115
self.cache_config = cache_config
116116
self.inputs = inputs if inputs is not None else {}
117117

118-
root_path = f"Steps.{name}"
119-
root_prop = Properties(path=root_path)
118+
root_prop = Properties(step_name=name)
120119

121120
property_dict = {}
122121
for output in self.outputs:
123122
property_dict[output.output_name] = Properties(
124-
f"{root_path}.OutputParameters['{output.output_name}']"
123+
step_name=name, path=f"OutputParameters['{output.output_name}']"
125124
)
126125

127126
root_prop.__dict__["Outputs"] = property_dict

src/sagemaker/workflow/pipeline.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -322,6 +322,16 @@ def _interpolate_step_collection_name_in_depends_on(self, step_requests: dict):
322322
depends_on.append(depend_step_name)
323323
step_request["DependsOn"] = depends_on
324324

325+
def _generate_adjacency_list(self):
    """Build the adjacency list of the step dependency DAG for this pipeline.

    Returns:
        dict: Maps a step name to the value returned by that step's
            ``_find_step_dependencies``. For a ``StepCollection`` entry, the
            mapping it returns is merged into the result instead.
    """
    graph = {}
    for node in self.steps:
        if isinstance(node, Step):
            graph[node.name] = node._find_step_dependencies()
            continue
        if isinstance(node, StepCollection):
            # A collection reports one entry per contained step; merge them in.
            graph.update(node._find_step_dependencies())
    return graph
334+
325335

326336
def format_start_parameters(parameters: Dict[str, Any]) -> List[Dict[str, Any]]:
327337
"""Formats start parameter overrides as a list of dicts.

src/sagemaker/workflow/properties.py

Lines changed: 33 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,8 @@ class Properties(PipelineVariable, metaclass=PropertiesMeta):
5050

5151
def __init__(
5252
self,
53-
path: str,
53+
step_name: str,
54+
path: str = None,
5455
shape_name: str = None,
5556
shape_names: List[str] = None,
5657
service_name: str = "sagemaker",
@@ -62,7 +63,9 @@ def __init__(
6263
shape_name (str): The botocore service model shape name.
6364
shape_names (List[str]): A list of botocore service model shape names.
6465
"""
65-
self._path = path
66+
self.step_name = step_name
67+
prefix = f"Steps.{step_name}"
68+
self._path = prefix if path is None else f"{prefix}.{path}"
6669
shape_names = [] if shape_names is None else shape_names
6770
self._shape_names = shape_names if shape_name is None else [shape_name] + shape_names
6871

@@ -78,15 +81,24 @@ def __init__(
7881
for key, info in members.items():
7982
if shapes.get(info["shape"], {}).get("type") == "list":
8083
self.__dict__[key] = PropertiesList(
81-
f"{path}.{key}", info["shape"], service_name
84+
step_name=step_name,
85+
path=".".join(filter(None, (path, key))),
86+
shape_name=info["shape"],
87+
service_name=service_name,
8288
)
8389
elif shapes.get(info["shape"], {}).get("type") == "map":
8490
self.__dict__[key] = PropertiesMap(
85-
f"{path}.{key}", info["shape"], service_name
91+
step_name=step_name,
92+
path=".".join(filter(None, (path, key))),
93+
shape_name=info["shape"],
94+
service_name=service_name,
8695
)
8796
else:
8897
self.__dict__[key] = Properties(
89-
f"{path}.{key}", info["shape"], service_name=service_name
98+
step_name=step_name,
99+
path=".".join(filter(None, (path, key))),
100+
shape_name=info["shape"],
101+
service_name=service_name,
90102
)
91103

92104
@property
@@ -98,17 +110,21 @@ def expr(self):
98110
class PropertiesList(Properties):
99111
"""PropertiesList for use in workflow expressions."""
100112

101-
def __init__(self, path: str, shape_name: str = None, service_name: str = "sagemaker"):
113+
def __init__(
114+
self, step_name: str, path: str, shape_name: str = None, service_name: str = "sagemaker"
115+
):
102116
"""Create a PropertiesList instance representing the given shape.
103117
104118
Args:
105119
path (str): The parent path of the PropertiesList instance.
106120
shape_name (str): The botocore service model shape name.
107121
service_name (str): The botocore service name.
108122
"""
109-
super(PropertiesList, self).__init__(path, shape_name)
123+
super(PropertiesList, self).__init__(step_name, path, shape_name)
124+
self.step_name = step_name
110125
self.shape_name = shape_name
111126
self.service_name = service_name
127+
self.path = path
112128
self._items: Dict[Union[int, str], Properties] = dict()
113129

114130
def __getitem__(self, item: Union[int, str]):
@@ -121,9 +137,9 @@ def __getitem__(self, item: Union[int, str]):
121137
shape = Properties._shapes_map.get(self.service_name, {}).get(self.shape_name)
122138
member = shape["member"]["shape"]
123139
if isinstance(item, str):
124-
property_item = Properties(f"{self._path}['{item}']", member)
140+
property_item = Properties(self.step_name, f"{self.path}['{item}']", member)
125141
else:
126-
property_item = Properties(f"{self._path}[{item}]", member)
142+
property_item = Properties(self.step_name, f"{self.path}[{item}]", member)
127143
self._items[item] = property_item
128144

129145
return self._items.get(item)
@@ -132,17 +148,21 @@ def __getitem__(self, item: Union[int, str]):
132148
class PropertiesMap(Properties):
133149
"""PropertiesMap for use in workflow expressions."""
134150

135-
def __init__(self, path: str, shape_name: str = None, service_name: str = "sagemaker"):
151+
def __init__(
152+
self, step_name: str, path: str, shape_name: str = None, service_name: str = "sagemaker"
153+
):
136154
"""Create a PropertiesMap instance representing the given shape.
137155
138156
Args:
139157
path (str): The parent path of the PropertiesMap instance.
140158
shape_name (str): The botocore sagemaker service model shape name.
141159
service_name (str): The botocore service name.
142160
"""
143-
super(PropertiesMap, self).__init__(path, shape_name)
161+
super(PropertiesMap, self).__init__(step_name, path, shape_name)
162+
self.step_name = step_name
144163
self.shape_name = shape_name
145164
self.service_name = service_name
165+
self.path = path
146166
self._items: Dict[Union[int, str], Properties] = dict()
147167

148168
def __getitem__(self, item: Union[int, str]):
@@ -155,9 +175,9 @@ def __getitem__(self, item: Union[int, str]):
155175
shape = Properties._shapes_map.get(self.service_name, {}).get(self.shape_name)
156176
member = shape["value"]["shape"]
157177
if isinstance(item, str):
158-
property_item = Properties(f"{self._path}['{item}']", member)
178+
property_item = Properties(self.step_name, f"{self.path}['{item}']", member)
159179
else:
160-
property_item = Properties(f"{self._path}[{item}]", member)
180+
property_item = Properties(self.step_name, f"{self.path}[{item}]", member)
161181
self._items[item] = property_item
162182

163183
return self._items.get(item)

src/sagemaker/workflow/quality_check_step.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -205,19 +205,18 @@ def __init__(
205205
],
206206
)
207207

208-
root_path = f"Steps.{name}"
209-
root_prop = Properties(path=root_path)
208+
root_prop = Properties(step_name=name)
210209
root_prop.__dict__["CalculatedBaselineConstraints"] = Properties(
211-
f"{root_path}.CalculatedBaselineConstraints"
210+
step_name=name, path="CalculatedBaselineConstraints"
212211
)
213212
root_prop.__dict__["CalculatedBaselineStatistics"] = Properties(
214-
f"{root_path}.CalculatedBaselineStatistics"
213+
step_name=name, path="CalculatedBaselineStatistics"
215214
)
216215
root_prop.__dict__["BaselineUsedForDriftCheckStatistics"] = Properties(
217-
f"{root_path}.BaselineUsedForDriftCheckStatistics"
216+
step_name=name, path="BaselineUsedForDriftCheckStatistics"
218217
)
219218
root_prop.__dict__["BaselineUsedForDriftCheckConstraints"] = Properties(
220-
f"{root_path}.BaselineUsedForDriftCheckConstraints"
219+
step_name=name, path="BaselineUsedForDriftCheckConstraints"
221220
)
222221
self._properties = root_prop
223222

src/sagemaker/workflow/step_collections.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,13 @@ def properties(self):
5454
size = len(self.steps)
5555
return self.steps[size - 1].properties
5656

57+
def _find_step_dependencies(self):
    """Map each step in this collection to the step names it depends on.

    Returns:
        dict: Keyed by the name of every step in ``self.steps``; each value
            is whatever that step's own ``_find_step_dependencies`` returns.
    """
    return {step.name: step._find_step_dependencies() for step in self.steps}
63+
5764

5865
class RegisterModel(StepCollection): # pragma: no cover
5966
"""Register Model step collection for workflow."""

src/sagemaker/workflow/steps.py

Lines changed: 32 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,33 @@ def _resolve_depends_on(
148148
raise ValueError(f"Invalid input step name: {step}")
149149
return depends_on
150150

151+
def _find_step_dependencies(self):
    """Find the names of the steps this step depends on.

    Dependencies come from two places: the entries of ``self.depends_on``
    (step names given as strings, or ``Step`` instances) and any
    ``Properties`` objects nested inside the dicts and lists of
    ``self.arguments``.

    Returns:
        list: The unique dependency step names, sorted so the result is
            identical on every run.
    """

    def _find_dependencies_in_step_arguments(obj):
        """Collect step names of ``Properties`` nested in dicts and lists."""
        if isinstance(obj, dict):
            children = obj.values()
        elif isinstance(obj, list):
            children = obj
        else:
            # Anything else (including a Properties object) is not traversed.
            return set()

        found = set()
        for child in children:
            if isinstance(child, Properties):
                found.add(child.step_name)
            found.update(_find_dependencies_in_step_arguments(child))
        return found

    step_dependencies = set()
    for depends_on_step in self.depends_on or []:
        if isinstance(depends_on_step, str):
            step_dependencies.add(depends_on_step)
        elif isinstance(depends_on_step, Step):
            step_dependencies.add(depends_on_step.name)
    step_dependencies.update(_find_dependencies_in_step_arguments(self.arguments))
    # Iteration order of a set of strings differs between interpreter runs,
    # so sort to give callers a stable result (assumes names are strings).
    return sorted(step_dependencies)
177+
151178

152179
@attr.s
153180
class CacheConfig:
@@ -290,9 +317,7 @@ def __init__(
290317
self.estimator = estimator
291318
self.inputs = inputs
292319

293-
self._properties = Properties(
294-
path=f"Steps.{name}", shape_name="DescribeTrainingJobResponse"
295-
)
320+
self._properties = Properties(step_name=name, shape_name="DescribeTrainingJobResponse")
296321
self.cache_config = cache_config
297322

298323
if self.cache_config:
@@ -430,7 +455,7 @@ def __init__(
430455
self.model = model
431456
self.inputs = inputs or CreateModelInput()
432457

433-
self._properties = Properties(path=f"Steps.{name}", shape_name="DescribeModelOutput")
458+
self._properties = Properties(step_name=name, shape_name="DescribeModelOutput")
434459

435460
# TODO: add public document link here once ready
436461
warnings.warn(
@@ -526,9 +551,7 @@ def __init__(
526551
self.transformer = transformer
527552
self.inputs = inputs
528553
self.cache_config = cache_config
529-
self._properties = Properties(
530-
path=f"Steps.{name}", shape_name="DescribeTransformJobResponse"
531-
)
554+
self._properties = Properties(step_name=name, shape_name="DescribeTransformJobResponse")
532555

533556
if not self.step_args:
534557
if inputs is None:
@@ -652,9 +675,7 @@ def __init__(
652675
self.job_name = None
653676
self.kms_key = kms_key
654677
self.cache_config = cache_config
655-
self._properties = Properties(
656-
path=f"Steps.{name}", shape_name="DescribeProcessingJobResponse"
657-
)
678+
self._properties = Properties(step_name=name, shape_name="DescribeProcessingJobResponse")
658679

659680
if not self.step_args:
660681
# Examine why run method in `sagemaker.processing.Processor`
@@ -806,7 +827,7 @@ def __init__(
806827
self.inputs = inputs
807828
self.job_arguments = job_arguments
808829
self._properties = Properties(
809-
path=f"Steps.{name}",
830+
step_name=name,
810831
shape_names=[
811832
"DescribeHyperParameterTuningJobResponse",
812833
"ListTrainingJobsForHyperParameterTuningJobResponse",

tests/unit/sagemaker/workflow/helpers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ def __init__(self, name, display_name=None, description=None, depends_on=None):
4242
super(CustomStep, self).__init__(
4343
name, display_name, description, StepTypeEnum.TRAINING, depends_on
4444
)
45-
self._properties = Properties(path=f"Steps.{name}")
45+
self._properties = Properties(name)
4646

4747
@property
4848
def arguments(self):

tests/unit/sagemaker/workflow/test_callback_step.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from sagemaker.workflow.parameters import ParameterInteger, ParameterString
2222
from sagemaker.workflow.pipeline import Pipeline
2323
from sagemaker.workflow.callback_step import CallbackStep, CallbackOutput, CallbackOutputTypeEnum
24-
from tests.unit.sagemaker.workflow.helpers import CustomStep
24+
from tests.unit.sagemaker.workflow.helpers import CustomStep, ordered
2525

2626

2727
@pytest.fixture
@@ -52,6 +52,7 @@ def test_callback_step():
5252
],
5353
"Arguments": {"arg1": "foo", "arg2": 5, "arg3": param},
5454
}
55+
assert ordered(cb_step._find_step_dependencies()) == ["SecondTestStep", "TestStep"]
5556

5657

5758
def test_callback_step_default_values():
@@ -75,6 +76,7 @@ def test_callback_step_default_values():
7576
],
7677
"Arguments": {"arg1": "foo", "arg2": 5, "arg3": param},
7778
}
79+
assert ordered(cb_step._find_step_dependencies()) == ["SecondTestStep", "TestStep"]
7880

7981

8082
def test_callback_step_output_expr():
@@ -95,6 +97,7 @@ def test_callback_step_output_expr():
9597
assert cb_step.properties.Outputs["output2"].expr == {
9698
"Get": "Steps.MyCallbackStep.OutputParameters['output2']"
9799
}
100+
assert cb_step._find_step_dependencies() == ["TestStep"]
98101

99102

100103
def test_pipeline_interpolates_callback_outputs():
@@ -156,3 +159,6 @@ def test_pipeline_interpolates_callback_outputs():
156159
},
157160
],
158161
}
162+
assert ordered(pipeline._generate_adjacency_list()) == ordered(
163+
{"MyCallbackStep1": ["TestStep"], "MyCallbackStep2": ["TestStep"], "TestStep": []}
164+
)

0 commit comments

Comments
 (0)