1
+ from pytensor import Variable
1
2
from pytensor .compile .mode import optdb
2
3
from pytensor .graph import Constant , node_rewriter
3
4
from pytensor .graph .replace import vectorize_node
4
5
from pytensor .graph .rewriting .basic import copy_stack_trace , out2in
5
6
from pytensor .tensor .basic import Alloc , ARange , alloc , shape_padleft
6
7
from pytensor .tensor .blockwise import Blockwise
8
+ from pytensor .tensor .elemwise import DimShuffle
7
9
from pytensor .tensor .math import Dot
8
10
from pytensor .tensor .rewriting .basic import (
9
11
register_canonicalize ,
10
12
register_specialize ,
11
13
register_stabilize ,
12
14
)
15
+ from pytensor .tensor .rewriting .uncanonicalize import local_dimshuffle_alloc
16
+ from pytensor .tensor .shape import Reshape
13
17
from pytensor .tensor .subtensor import AdvancedIncSubtensor , AdvancedSubtensor , Subtensor
14
18
15
19
@@ -70,7 +74,7 @@ def local_eager_useless_unbatched_blockwise(fgraph, node):
70
74
Dot | Alloc | ARange | Subtensor | AdvancedSubtensor | AdvancedIncSubtensor ,
71
75
):
72
76
# Many Dot-related rewrites (eg, all of BlasOpt) happen before specialize
73
- # These other Ops can't always be trivially vectored at runtime,
77
+ # These other Ops can't always be trivially vectorized at runtime,
74
78
# since their inputs may imply non-rectangular shapes.
75
79
return local_useless_unbatched_blockwise .fn (fgraph , node )
76
80
@@ -86,6 +90,18 @@ def _squeeze_left(x, stop_at_dim: int | None = None):
86
90
return x .squeeze (axis = tuple (range (squeeze_ndim )))
87
91
88
92
93
def alloc_or_expand_dims_of_alloc(var: "Variable") -> bool:
    """Return whether `var` is an ``Alloc``, or a ``DimShuffle`` of an ``Alloc``.

    Used to decide eagerly whether `local_blockwise_alloc` has any chance of
    pushing batch dimensions out of a `Blockwise` input.

    Parameters
    ----------
    var
        Any graph variable.

    Returns
    -------
    bool
        True iff `var` is produced by an ``Alloc``, or by a ``DimShuffle``
        whose input is produced by an ``Alloc``.
    """
    owner = var.owner
    if owner is None:
        # Root variables (graph inputs / constants) have no owner; the
        # original expression returned `None` here despite the `-> bool`
        # annotation — normalize to an actual bool.
        return False
    if isinstance(owner.op, Alloc):
        return True
    return (
        isinstance(owner.op, DimShuffle)
        and owner.inputs[0].owner is not None
        and isinstance(owner.inputs[0].owner.op, Alloc)
    )
102
+
103
+
104
+ @register_canonicalize ("shape_unsafe" )
89
105
@register_specialize ("shape_unsafe" )
90
106
@node_rewriter ([Blockwise ])
91
107
def local_blockwise_alloc (fgraph , node ):
@@ -97,62 +113,73 @@ def local_blockwise_alloc(fgraph, node):
97
113
BOp(matrix, alloc(vector, 10, 5)) -> BOp(matrix, vector)
98
114
"""
99
115
100
- if not any (isinstance (inp .owner .op , Alloc ) for inp in node .inputs if inp .owner ):
101
- return None
102
-
103
116
op : Blockwise = node .op # type: ignore
104
117
105
118
batch_ndim = op .batch_ndim (node )
106
119
if not batch_ndim :
107
120
return None
108
121
122
+ if not any (alloc_or_expand_dims_of_alloc (var ) for var in node .inputs ):
123
+ return None
124
+
109
125
new_inputs = []
110
126
batch_shapes = []
111
127
can_push_any_alloc = False
112
128
for inp , inp_sig in zip (node .inputs , op .inputs_sig ):
113
- if inp .owner and isinstance (inp .owner .op , Alloc ):
114
- # Push batch dims from Alloc
115
- value , * shape = inp .owner .inputs
116
-
117
- # Check what to do with the value of the Alloc
118
- squeezed_value = _squeeze_left (value , batch_ndim )
119
- missing_ndim = len (shape ) - value .type .ndim
120
- if (
121
- (((1 ,) * missing_ndim + value .type .broadcastable )[batch_ndim :])
122
- != inp .type .broadcastable [batch_ndim :]
123
- ):
124
- # We still need an Alloc for the core dims
125
- core_shape = shape [batch_ndim :]
126
- # And the batch dims of the squeezed value
127
- squeezed_value_batch_ndim = squeezed_value .type .ndim - len (core_shape )
128
- batch_shape = [
129
- 1 if broadcastable else dim
130
- for broadcastable , dim in zip (
131
- squeezed_value .type .broadcastable [:squeezed_value_batch_ndim ],
132
- tuple (squeezed_value .shape )[:squeezed_value_batch_ndim ],
129
+ if not all (inp .type .broadcastable [:batch_ndim ]):
130
+ if inp .owner and isinstance (inp .owner .op , DimShuffle ):
131
+ # Convert DimShuffle of Alloc to Alloc
132
+ new_inp = local_dimshuffle_alloc .transform (None , inp .owner )
133
+ if new_inp :
134
+ [inp ] = new_inp
135
+
136
+ if inp .owner and isinstance (inp .owner .op , Alloc ):
137
+ # Push batch dims from Alloc
138
+ value , * shape = inp .owner .inputs
139
+
140
+ # Check what to do with the value of the Alloc
141
+ squeezed_value = _squeeze_left (value , batch_ndim )
142
+ missing_ndim = len (shape ) - value .type .ndim
143
+ if (
144
+ (((1 ,) * missing_ndim + value .type .broadcastable )[batch_ndim :])
145
+ != inp .type .broadcastable [batch_ndim :]
146
+ ):
147
+ # We still need an Alloc for the core dims
148
+ core_shape = shape [batch_ndim :]
149
+ # And the batch dims of the squeezed value
150
+ squeezed_value_batch_ndim = squeezed_value .type .ndim - len (
151
+ core_shape
133
152
)
134
- ]
135
- squeezed_value = alloc (squeezed_value , * batch_shape , * core_shape )
136
- if squeezed_value .type .broadcastable == inp .type .broadcastable :
137
- # We can't change anything about this Alloc input
138
- new_inputs .append (inp )
139
- continue
140
-
141
- # We can push batch dims of this Alloc input
142
- batch_shapes .append (
143
- tuple (
144
- 1 if broadcastable else dim
145
- for broadcastable , dim in zip (
146
- inp .type .broadcastable , shape [:batch_ndim ]
153
+ batch_shape = [
154
+ 1 if broadcastable else dim
155
+ for broadcastable , dim in zip (
156
+ squeezed_value .type .broadcastable [
157
+ :squeezed_value_batch_ndim
158
+ ],
159
+ tuple (squeezed_value .shape )[:squeezed_value_batch_ndim ],
160
+ )
161
+ ]
162
+ squeezed_value = alloc (squeezed_value , * batch_shape , * core_shape )
163
+ if squeezed_value .type .broadcastable == inp .type .broadcastable :
164
+ # We can't change anything about this Alloc input
165
+ new_inputs .append (inp )
166
+ continue
167
+
168
+ # We can push batch dims of this Alloc input
169
+ batch_shapes .append (
170
+ tuple (
171
+ 1 if broadcastable else dim
172
+ for broadcastable , dim in zip (
173
+ inp .type .broadcastable , shape [:batch_ndim ]
174
+ )
147
175
)
148
176
)
149
- )
150
- new_inputs . append ( squeezed_value )
151
- can_push_any_alloc = True
177
+ new_inputs . append ( squeezed_value )
178
+ can_push_any_alloc = True
179
+ continue
152
180
153
- else :
154
- # Nothing to do with this input other than removing dummy batch dims
155
- new_inputs .append (_squeeze_left (inp , batch_ndim ))
181
+ # Nothing to do with this input other than removing dummy batch dims
182
+ new_inputs .append (_squeeze_left (inp , batch_ndim ))
156
183
157
184
if not can_push_any_alloc :
158
185
return None
@@ -167,17 +194,15 @@ def local_blockwise_alloc(fgraph, node):
167
194
missing_ndim = old_out_type .ndim - new_out_type .ndim
168
195
batch_shape = ([1 ] * missing_ndim + list (new_outs [0 ].shape ))[:batch_ndim ]
169
196
for i , batch_dims in enumerate (zip (* batch_shapes )): # Transpose shape tuples
197
+ if old_out_type .broadcastable [i ]:
198
+ continue
170
199
for batch_dim in batch_dims :
171
200
if batch_dim == 1 :
172
201
continue
202
+ batch_shape [i ] = batch_dim
173
203
if isinstance (batch_dim , Constant ):
174
204
# Give preference to Constants
175
- batch_shape [i ] = batch_dim
176
205
break
177
- elif old_out_type .broadcastable [i ]:
178
- # Only use non Constant shapes if absolutely necessary
179
- # Otherwise, we use the shape of the non-alloc output
180
- batch_shape [i ] = batch_dim
181
206
182
207
copy_stack_trace (node .outputs , new_outs )
183
208
new_outs = [
@@ -190,3 +215,29 @@ def local_blockwise_alloc(fgraph, node):
190
215
]
191
216
copy_stack_trace (node .outputs , new_outs )
192
217
return new_outs
218
+
219
+
220
@register_canonicalize
@register_specialize
@node_rewriter([Blockwise])
def local_blockwise_reshape(fgraph, node):
    """Rewrite away square Blockwise reshapes.

    Reshape is tricky to vectorize eagerly, because a graph like
    `x.reshape([x.shape[0] * x.shape[1], -1])` has many operations
    that must be vectorized before we arrive at the reshape operation.

    For the square Reshape case, we must wait for all the intermediate
    operations to be lifted as Allocs
    """
    # Only applies to Blockwise-wrapped Reshape ops.
    if not isinstance(node.op.core_op, Reshape):
        return None

    x, output_shape = node.inputs
    batch_ndim = node.op.batch_ndim(node)
    # Only rewrite when every batch dimension of the target-shape input is
    # broadcastable (length 1), i.e. the requested shape is the same across
    # the whole batch ("square" case).
    if all(output_shape.type.broadcastable[:batch_ndim]):
        # Keep the batch dims of `x` as-is and apply the core reshape to the
        # remaining dims, replacing the Blockwise with a plain Reshape.
        batched_shape = x.shape[:batch_ndim]
        core_reshape = _squeeze_left(output_shape, batch_ndim)
        new_out = x.reshape([*tuple(batched_shape), *tuple(core_reshape)])
        copy_stack_trace(node.outputs[0], new_out)
        return [new_out]
0 commit comments