Skip to content

Commit 14d2454

Browse files
authored
Unify signatures of graph_replace and clone_replace (#398)
* more type hints
1 parent e9a7d7c commit 14d2454

File tree

3 files changed

+206
-45
lines changed

3 files changed

+206
-45
lines changed

pytensor/compile/function/pfunc.py

Lines changed: 105 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
"""
55

66
from copy import copy
7-
from typing import Optional
7+
from typing import Optional, Sequence, Union, overload
88

99
from pytensor.compile.function.types import Function, UnusedInputError, orig_function
1010
from pytensor.compile.io import In, Out
@@ -15,16 +15,117 @@
1515
from pytensor.graph.fg import FunctionGraph
1616

1717

18+
@overload
1819
def rebuild_collect_shared(
19-
outputs,
20+
outputs: Variable,
2021
inputs=None,
2122
replace=None,
2223
updates=None,
2324
rebuild_strict=True,
2425
copy_inputs_over=True,
2526
no_default_updates=False,
2627
clone_inner_graphs=False,
27-
):
28+
) -> tuple[
29+
list[Variable],
30+
Variable,
31+
tuple[
32+
dict[Variable, Variable],
33+
dict[SharedVariable, Variable],
34+
list[Variable],
35+
list[SharedVariable],
36+
],
37+
]:
38+
...
39+
40+
41+
@overload
42+
def rebuild_collect_shared(
43+
outputs: Sequence[Variable],
44+
inputs=None,
45+
replace=None,
46+
updates=None,
47+
rebuild_strict=True,
48+
copy_inputs_over=True,
49+
no_default_updates=False,
50+
clone_inner_graphs=False,
51+
) -> tuple[
52+
list[Variable],
53+
list[Variable],
54+
tuple[
55+
dict[Variable, Variable],
56+
dict[SharedVariable, Variable],
57+
list[Variable],
58+
list[SharedVariable],
59+
],
60+
]:
61+
...
62+
63+
64+
@overload
65+
def rebuild_collect_shared(
66+
outputs: Out,
67+
inputs=None,
68+
replace=None,
69+
updates=None,
70+
rebuild_strict=True,
71+
copy_inputs_over=True,
72+
no_default_updates=False,
73+
clone_inner_graphs=False,
74+
) -> tuple[
75+
list[Variable],
76+
Out,
77+
tuple[
78+
dict[Variable, Variable],
79+
dict[SharedVariable, Variable],
80+
list[Variable],
81+
list[SharedVariable],
82+
],
83+
]:
84+
...
85+
86+
87+
@overload
88+
def rebuild_collect_shared(
89+
outputs: Sequence[Out],
90+
inputs=None,
91+
replace=None,
92+
updates=None,
93+
rebuild_strict=True,
94+
copy_inputs_over=True,
95+
no_default_updates=False,
96+
clone_inner_graphs=False,
97+
) -> tuple[
98+
list[Variable],
99+
list[Out],
100+
tuple[
101+
dict[Variable, Variable],
102+
dict[SharedVariable, Variable],
103+
list[Variable],
104+
list[SharedVariable],
105+
],
106+
]:
107+
...
108+
109+
110+
def rebuild_collect_shared(
111+
outputs: Union[Sequence[Variable], Variable, Out, Sequence[Out]],
112+
inputs=None,
113+
replace=None,
114+
updates=None,
115+
rebuild_strict=True,
116+
copy_inputs_over=True,
117+
no_default_updates=False,
118+
clone_inner_graphs=False,
119+
) -> tuple[
120+
list[Variable],
121+
Union[list[Variable], Variable, Out, list[Out]],
122+
tuple[
123+
dict[Variable, Variable],
124+
dict[SharedVariable, Variable],
125+
list[Variable],
126+
list[SharedVariable],
127+
],
128+
]:
28129
r"""Replace subgraphs of a computational graph.
29130
30131
It returns a set of dictionaries and lists which collect (partial?)
@@ -260,7 +361,7 @@ def clone_inputs(i):
260361
return (
261362
input_variables,
262363
cloned_outputs,
263-
[clone_d, update_d, update_expr, shared_inputs],
364+
(clone_d, update_d, update_expr, shared_inputs),
264365
)
265366

266367

pytensor/graph/replace.py

Lines changed: 90 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,56 @@
11
from functools import partial
2-
from typing import (
3-
Collection,
4-
Dict,
5-
Iterable,
6-
List,
7-
Optional,
8-
Sequence,
9-
Tuple,
10-
Union,
11-
cast,
12-
)
13-
14-
from pytensor.graph.basic import Constant, Variable, truncated_graph_inputs
2+
from typing import Iterable, Optional, Sequence, Union, cast, overload
3+
4+
from pytensor.graph.basic import Apply, Constant, Variable, truncated_graph_inputs
155
from pytensor.graph.fg import FunctionGraph
166

177

8+
ReplaceTypes = Union[Iterable[tuple[Variable, Variable]], dict[Variable, Variable]]
9+
10+
11+
def _format_replace(replace: Optional[ReplaceTypes] = None) -> dict[Variable, Variable]:
12+
items: dict[Variable, Variable]
13+
if isinstance(replace, dict):
14+
# PyLance has issues with type resolution
15+
items = cast(dict[Variable, Variable], replace)
16+
elif isinstance(replace, Iterable):
17+
items = dict(replace)
18+
elif replace is None:
19+
items = {}
20+
else:
21+
raise ValueError(
22+
"replace is neither a dictionary, list, "
23+
f"tuple or None ! The value provided is {replace},"
24+
f"of type {type(replace)}"
25+
)
26+
return items
27+
28+
29+
@overload
30+
def clone_replace(
31+
output: Sequence[Variable],
32+
replace: Optional[ReplaceTypes] = None,
33+
**rebuild_kwds,
34+
) -> list[Variable]:
35+
...
36+
37+
38+
@overload
1839
def clone_replace(
19-
output: Collection[Variable],
40+
output: Variable,
2041
replace: Optional[
21-
Union[Iterable[Tuple[Variable, Variable]], Dict[Variable, Variable]]
42+
Union[Iterable[tuple[Variable, Variable]], dict[Variable, Variable]]
2243
] = None,
2344
**rebuild_kwds,
24-
) -> List[Variable]:
45+
) -> Variable:
46+
...
47+
48+
49+
def clone_replace(
50+
output: Union[Sequence[Variable], Variable],
51+
replace: Optional[ReplaceTypes] = None,
52+
**rebuild_kwds,
53+
) -> Union[list[Variable], Variable]:
2554
"""Clone a graph and replace subgraphs within it.
2655
2756
It returns a copy of the initial subgraph with the corresponding
@@ -39,40 +68,49 @@ def clone_replace(
3968
"""
4069
from pytensor.compile.function.pfunc import rebuild_collect_shared
4170

42-
items: Union[List[Tuple[Variable, Variable]], Tuple[Tuple[Variable, Variable], ...]]
43-
if isinstance(replace, dict):
44-
items = list(replace.items())
45-
elif isinstance(replace, (list, tuple)):
46-
items = replace
47-
elif replace is None:
48-
items = []
49-
else:
50-
raise ValueError(
51-
"replace is neither a dictionary, list, "
52-
f"tuple or None ! The value provided is {replace},"
53-
f"of type {type(replace)}"
54-
)
71+
items = list(_format_replace(replace).items())
72+
5573
tmp_replace = [(x, x.type()) for x, y in items]
5674
new_replace = [(x, y) for ((_, x), (_, y)) in zip(tmp_replace, items)]
5775
_, _outs, _ = rebuild_collect_shared(output, [], tmp_replace, [], **rebuild_kwds)
5876

5977
# TODO Explain why we call it twice ?!
6078
_, outs, _ = rebuild_collect_shared(_outs, [], new_replace, [], **rebuild_kwds)
6179

62-
return cast(List[Variable], outs)
80+
return outs
6381

6482

83+
@overload
84+
def graph_replace(
85+
outputs: Variable,
86+
replace: Optional[ReplaceTypes] = None,
87+
*,
88+
strict=True,
89+
) -> Variable:
90+
...
91+
92+
93+
@overload
6594
def graph_replace(
6695
outputs: Sequence[Variable],
67-
replace: Dict[Variable, Variable],
96+
replace: Optional[ReplaceTypes] = None,
97+
*,
98+
strict=True,
99+
) -> list[Variable]:
100+
...
101+
102+
103+
def graph_replace(
104+
outputs: Union[Sequence[Variable], Variable],
105+
replace: Optional[ReplaceTypes] = None,
68106
*,
69107
strict=True,
70-
) -> List[Variable]:
108+
) -> Union[list[Variable], Variable]:
71109
"""Replace variables in ``outputs`` by ``replace``.
72110
73111
Parameters
74112
----------
75-
outputs: Sequence[Variable]
113+
outputs: Union[Sequence[Variable], Variable]
76114
Output graph
77115
replace: Dict[Variable, Variable]
78116
Replace mapping
@@ -83,20 +121,26 @@ def graph_replace(
83121
84122
Returns
85123
-------
86-
List[Variable]
87-
Output graph with subgraphs replaced
124+
Union[Variable, List[Variable]]
125+
Output graph with subgraphs replaced, see function overload for the exact type
88126
89127
Raises
90128
------
91129
ValueError
92-
If some replacemens could not be applied and strict is True
130+
If some replacements could not be applied and strict is True
93131
"""
132+
as_list = False
133+
if not isinstance(outputs, Sequence):
134+
outputs = [outputs]
135+
else:
136+
as_list = True
137+
replace_dict = _format_replace(replace)
94138
# collect minimum graph inputs which is required to compute outputs
95139
# and depend on replacements
96140
# additionally remove constants, they do not matter in clone get equiv
97141
conditions = [
98142
c
99-
for c in truncated_graph_inputs(outputs, replace)
143+
for c in truncated_graph_inputs(outputs, replace_dict)
100144
if not isinstance(c, Constant)
101145
]
102146
# for the function graph we need the clean graph where
@@ -117,7 +161,7 @@ def graph_replace(
117161
# replace the conditions back
118162
fg_replace = {equiv[c]: c for c in conditions}
119163
# add the replacements on top of input mappings
120-
fg_replace.update({equiv[r]: v for r, v in replace.items() if r in equiv})
164+
fg_replace.update({equiv[r]: v for r, v in replace_dict.items() if r in equiv})
121165
# replacements have to be done in reverse topological order so that nested
122166
# expressions get recursively replaced correctly
123167

@@ -126,12 +170,14 @@ def graph_replace(
126170
# So far FunctionGraph does these replacements inplace it is thus unsafe
127171
# apply them using fg.replace, it may change the original graph
128172
if strict:
129-
non_fg_replace = {r: v for r, v in replace.items() if r not in equiv}
173+
non_fg_replace = {r: v for r, v in replace_dict.items() if r not in equiv}
130174
if non_fg_replace:
131175
raise ValueError(f"Some replacements were not used: {non_fg_replace}")
132176
toposort = fg.toposort()
133177

134-
def toposort_key(fg: FunctionGraph, ts, pair):
178+
def toposort_key(
179+
fg: FunctionGraph, ts: list[Apply], pair: tuple[Variable, Variable]
180+
) -> int:
135181
key, _ = pair
136182
if key.owner is not None:
137183
return ts.index(key.owner)
@@ -148,4 +194,7 @@ def toposort_key(fg: FunctionGraph, ts, pair):
148194
reverse=True,
149195
)
150196
fg.replace_all(sorted_replacements, import_missing=True)
151-
return list(fg.outputs)
197+
if as_list:
198+
return list(fg.outputs)
199+
else:
200+
return fg.outputs[0]

tests/graph/test_replace.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,17 @@ def test_graph_replace(self):
169169
# the old reference is still kept
170170
assert oc.owner.inputs[0].owner.inputs[1] is w
171171

172+
def test_non_list_input(self):
173+
x = MyVariable("x")
174+
y = MyVariable("y")
175+
o = MyOp("xyop")(x, y)
176+
new_x = x.clone(name="x_new")
177+
new_y = y.clone(name="y2_new")
178+
# test non list inputs as well
179+
oc = graph_replace(o, {x: new_x, y: new_y})
180+
assert oc.owner.inputs[1] is new_y
181+
assert oc.owner.inputs[0] is new_x
182+
172183
def test_graph_replace_advanced(self):
173184
x = MyVariable("x")
174185
y = MyVariable("y")

0 commit comments

Comments
 (0)