1
1
from functools import partial
2
- from typing import (
3
- Collection ,
4
- Dict ,
5
- Iterable ,
6
- List ,
7
- Optional ,
8
- Sequence ,
9
- Tuple ,
10
- Union ,
11
- cast ,
12
- )
13
-
14
- from pytensor .graph .basic import Constant , Variable , truncated_graph_inputs
2
+ from typing import Iterable , Optional , Sequence , Union , cast , overload
3
+
4
+ from pytensor .graph .basic import Apply , Constant , Variable , truncated_graph_inputs
15
5
from pytensor .graph .fg import FunctionGraph
16
6
17
7
8
+ ReplaceTypes = Union [Iterable [tuple [Variable , Variable ]], dict [Variable , Variable ]]
9
+
10
+
11
+ def _format_replace (replace : Optional [ReplaceTypes ] = None ) -> dict [Variable , Variable ]:
12
+ items : dict [Variable , Variable ]
13
+ if isinstance (replace , dict ):
14
+ # PyLance has issues with type resolution
15
+ items = cast (dict [Variable , Variable ], replace )
16
+ elif isinstance (replace , Iterable ):
17
+ items = dict (replace )
18
+ elif replace is None :
19
+ items = {}
20
+ else :
21
+ raise ValueError (
22
+ "replace is neither a dictionary, list, "
23
+ f"tuple or None ! The value provided is { replace } ,"
24
+ f"of type { type (replace )} "
25
+ )
26
+ return items
27
+
28
+
29
+ @overload
30
+ def clone_replace (
31
+ output : Sequence [Variable ],
32
+ replace : Optional [ReplaceTypes ] = None ,
33
+ ** rebuild_kwds ,
34
+ ) -> list [Variable ]:
35
+ ...
36
+
37
+
38
+ @overload
18
39
def clone_replace (
19
- output : Collection [ Variable ] ,
40
+ output : Variable ,
20
41
replace : Optional [
21
- Union [Iterable [Tuple [Variable , Variable ]], Dict [Variable , Variable ]]
42
+ Union [Iterable [tuple [Variable , Variable ]], dict [Variable , Variable ]]
22
43
] = None ,
23
44
** rebuild_kwds ,
24
- ) -> List [Variable ]:
45
+ ) -> Variable :
46
+ ...
47
+
48
+
49
+ def clone_replace (
50
+ output : Union [Sequence [Variable ], Variable ],
51
+ replace : Optional [ReplaceTypes ] = None ,
52
+ ** rebuild_kwds ,
53
+ ) -> Union [list [Variable ], Variable ]:
25
54
"""Clone a graph and replace subgraphs within it.
26
55
27
56
It returns a copy of the initial subgraph with the corresponding
@@ -39,40 +68,49 @@ def clone_replace(
39
68
"""
40
69
from pytensor .compile .function .pfunc import rebuild_collect_shared
41
70
42
- items : Union [List [Tuple [Variable , Variable ]], Tuple [Tuple [Variable , Variable ], ...]]
43
- if isinstance (replace , dict ):
44
- items = list (replace .items ())
45
- elif isinstance (replace , (list , tuple )):
46
- items = replace
47
- elif replace is None :
48
- items = []
49
- else :
50
- raise ValueError (
51
- "replace is neither a dictionary, list, "
52
- f"tuple or None ! The value provided is { replace } ,"
53
- f"of type { type (replace )} "
54
- )
71
+ items = list (_format_replace (replace ).items ())
72
+
55
73
tmp_replace = [(x , x .type ()) for x , y in items ]
56
74
new_replace = [(x , y ) for ((_ , x ), (_ , y )) in zip (tmp_replace , items )]
57
75
_ , _outs , _ = rebuild_collect_shared (output , [], tmp_replace , [], ** rebuild_kwds )
58
76
59
77
# TODO Explain why we call it twice ?!
60
78
_ , outs , _ = rebuild_collect_shared (_outs , [], new_replace , [], ** rebuild_kwds )
61
79
62
- return cast ( List [ Variable ], outs )
80
+ return outs
63
81
64
82
83
+ @overload
84
+ def graph_replace (
85
+ outputs : Variable ,
86
+ replace : Optional [ReplaceTypes ] = None ,
87
+ * ,
88
+ strict = True ,
89
+ ) -> Variable :
90
+ ...
91
+
92
+
93
+ @overload
65
94
def graph_replace (
66
95
outputs : Sequence [Variable ],
67
- replace : Dict [Variable , Variable ],
96
+ replace : Optional [ReplaceTypes ] = None ,
97
+ * ,
98
+ strict = True ,
99
+ ) -> list [Variable ]:
100
+ ...
101
+
102
+
103
+ def graph_replace (
104
+ outputs : Union [Sequence [Variable ], Variable ],
105
+ replace : Optional [ReplaceTypes ] = None ,
68
106
* ,
69
107
strict = True ,
70
- ) -> List [ Variable ]:
108
+ ) -> Union [ list [ Variable ], Variable ]:
71
109
"""Replace variables in ``outputs`` by ``replace``.
72
110
73
111
Parameters
74
112
----------
75
- outputs: Sequence[Variable]
113
+ outputs: Union[ Sequence[Variable], Variable]
76
114
Output graph
77
115
replace: Dict[Variable, Variable]
78
116
Replace mapping
@@ -83,20 +121,26 @@ def graph_replace(
83
121
84
122
Returns
85
123
-------
86
- List[Variable]
87
- Output graph with subgraphs replaced
124
+ Union[Variable, List[Variable] ]
125
+ Output graph with subgraphs replaced, see function overload for the exact type
88
126
89
127
Raises
90
128
------
91
129
ValueError
92
- If some replacemens could not be applied and strict is True
130
+ If some replacements could not be applied and strict is True
93
131
"""
132
+ as_list = False
133
+ if not isinstance (outputs , Sequence ):
134
+ outputs = [outputs ]
135
+ else :
136
+ as_list = True
137
+ replace_dict = _format_replace (replace )
94
138
# collect minimum graph inputs which is required to compute outputs
95
139
# and depend on replacements
96
140
# additionally remove constants, they do not matter in clone get equiv
97
141
conditions = [
98
142
c
99
- for c in truncated_graph_inputs (outputs , replace )
143
+ for c in truncated_graph_inputs (outputs , replace_dict )
100
144
if not isinstance (c , Constant )
101
145
]
102
146
# for the function graph we need the clean graph where
@@ -117,7 +161,7 @@ def graph_replace(
117
161
# replace the conditions back
118
162
fg_replace = {equiv [c ]: c for c in conditions }
119
163
# add the replacements on top of input mappings
120
- fg_replace .update ({equiv [r ]: v for r , v in replace .items () if r in equiv })
164
+ fg_replace .update ({equiv [r ]: v for r , v in replace_dict .items () if r in equiv })
121
165
# replacements have to be done in reverse topological order so that nested
122
166
# expressions get recursively replaced correctly
123
167
@@ -126,12 +170,14 @@ def graph_replace(
126
170
# So far FunctionGraph does these replacements inplace it is thus unsafe
127
171
# apply them using fg.replace, it may change the original graph
128
172
if strict :
129
- non_fg_replace = {r : v for r , v in replace .items () if r not in equiv }
173
+ non_fg_replace = {r : v for r , v in replace_dict .items () if r not in equiv }
130
174
if non_fg_replace :
131
175
raise ValueError (f"Some replacements were not used: { non_fg_replace } " )
132
176
toposort = fg .toposort ()
133
177
134
- def toposort_key (fg : FunctionGraph , ts , pair ):
178
+ def toposort_key (
179
+ fg : FunctionGraph , ts : list [Apply ], pair : tuple [Variable , Variable ]
180
+ ) -> int :
135
181
key , _ = pair
136
182
if key .owner is not None :
137
183
return ts .index (key .owner )
@@ -148,4 +194,7 @@ def toposort_key(fg: FunctionGraph, ts, pair):
148
194
reverse = True ,
149
195
)
150
196
fg .replace_all (sorted_replacements , import_missing = True )
151
- return list (fg .outputs )
197
+ if as_list :
198
+ return list (fg .outputs )
199
+ else :
200
+ return fg .outputs [0 ]
0 commit comments