Skip to content

Commit 1399c2f

Browse files
committed
Get rid of expensive Blockwise(Reshape)
1 parent 1f371b7 commit 1399c2f

File tree

5 files changed

+259
-129
lines changed

5 files changed

+259
-129
lines changed

pytensor/tensor/rewriting/blockwise.py

Lines changed: 99 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,19 @@
1+
from pytensor import Variable
12
from pytensor.compile.mode import optdb
23
from pytensor.graph import Constant, node_rewriter
34
from pytensor.graph.replace import vectorize_node
45
from pytensor.graph.rewriting.basic import copy_stack_trace, out2in
56
from pytensor.tensor.basic import Alloc, ARange, alloc, shape_padleft
67
from pytensor.tensor.blockwise import Blockwise
8+
from pytensor.tensor.elemwise import DimShuffle
79
from pytensor.tensor.math import Dot
810
from pytensor.tensor.rewriting.basic import (
911
register_canonicalize,
1012
register_specialize,
1113
register_stabilize,
1214
)
15+
from pytensor.tensor.rewriting.uncanonicalize import local_dimshuffle_alloc
16+
from pytensor.tensor.shape import Reshape
1317
from pytensor.tensor.subtensor import AdvancedIncSubtensor, AdvancedSubtensor, Subtensor
1418

1519

@@ -70,7 +74,7 @@ def local_eager_useless_unbatched_blockwise(fgraph, node):
7074
Dot | Alloc | ARange | Subtensor | AdvancedSubtensor | AdvancedIncSubtensor,
7175
):
7276
# Many Dot-related rewrites (eg, all of BlasOpt) happen before specialize
73-
# These other Ops can't always be trivially vectored at runtime,
77+
# These other Ops can't always be trivially vectorized at runtime,
7478
# since their inputs may imply non-rectangular shapes.
7579
return local_useless_unbatched_blockwise.fn(fgraph, node)
7680

@@ -86,6 +90,18 @@ def _squeeze_left(x, stop_at_dim: int | None = None):
8690
return x.squeeze(axis=tuple(range(squeeze_ndim)))
8791

8892

93+
def alloc_or_expand_dims_of_alloc(var: Variable) -> bool:
94+
return var.owner and (
95+
isinstance(var.owner.op, Alloc)
96+
or (
97+
isinstance(var.owner.op, DimShuffle)
98+
and var.owner.inputs[0].owner
99+
and isinstance(var.owner.inputs[0].owner.op, Alloc)
100+
)
101+
)
102+
103+
104+
@register_canonicalize("shape_unsafe")
89105
@register_specialize("shape_unsafe")
90106
@node_rewriter([Blockwise])
91107
def local_blockwise_alloc(fgraph, node):
@@ -97,62 +113,73 @@ def local_blockwise_alloc(fgraph, node):
97113
BOp(matrix, alloc(vector, 10, 5)) -> BOp(matrix, vector)
98114
"""
99115

100-
if not any(isinstance(inp.owner.op, Alloc) for inp in node.inputs if inp.owner):
101-
return None
102-
103116
op: Blockwise = node.op # type: ignore
104117

105118
batch_ndim = op.batch_ndim(node)
106119
if not batch_ndim:
107120
return None
108121

122+
if not any(alloc_or_expand_dims_of_alloc(var) for var in node.inputs):
123+
return None
124+
109125
new_inputs = []
110126
batch_shapes = []
111127
can_push_any_alloc = False
112128
for inp, inp_sig in zip(node.inputs, op.inputs_sig):
113-
if inp.owner and isinstance(inp.owner.op, Alloc):
114-
# Push batch dims from Alloc
115-
value, *shape = inp.owner.inputs
116-
117-
# Check what to do with the value of the Alloc
118-
squeezed_value = _squeeze_left(value, batch_ndim)
119-
missing_ndim = len(shape) - value.type.ndim
120-
if (
121-
(((1,) * missing_ndim + value.type.broadcastable)[batch_ndim:])
122-
!= inp.type.broadcastable[batch_ndim:]
123-
):
124-
# We still need an Alloc for the core dims
125-
core_shape = shape[batch_ndim:]
126-
# And the batch dims of the squeezed value
127-
squeezed_value_batch_ndim = squeezed_value.type.ndim - len(core_shape)
128-
batch_shape = [
129-
1 if broadcastable else dim
130-
for broadcastable, dim in zip(
131-
squeezed_value.type.broadcastable[:squeezed_value_batch_ndim],
132-
tuple(squeezed_value.shape)[:squeezed_value_batch_ndim],
129+
if not all(inp.type.broadcastable[:batch_ndim]):
130+
if inp.owner and isinstance(inp.owner.op, DimShuffle):
131+
# Convert DimShuffle of Alloc to Alloc
132+
new_inp = local_dimshuffle_alloc.transform(None, inp.owner)
133+
if new_inp:
134+
[inp] = new_inp
135+
136+
if inp.owner and isinstance(inp.owner.op, Alloc):
137+
# Push batch dims from Alloc
138+
value, *shape = inp.owner.inputs
139+
140+
# Check what to do with the value of the Alloc
141+
squeezed_value = _squeeze_left(value, batch_ndim)
142+
missing_ndim = len(shape) - value.type.ndim
143+
if (
144+
(((1,) * missing_ndim + value.type.broadcastable)[batch_ndim:])
145+
!= inp.type.broadcastable[batch_ndim:]
146+
):
147+
# We still need an Alloc for the core dims
148+
core_shape = shape[batch_ndim:]
149+
# And the batch dims of the squeezed value
150+
squeezed_value_batch_ndim = squeezed_value.type.ndim - len(
151+
core_shape
133152
)
134-
]
135-
squeezed_value = alloc(squeezed_value, *batch_shape, *core_shape)
136-
if squeezed_value.type.broadcastable == inp.type.broadcastable:
137-
# We can't change anything about this Alloc input
138-
new_inputs.append(inp)
139-
continue
140-
141-
# We can push batch dims of this Alloc input
142-
batch_shapes.append(
143-
tuple(
144-
1 if broadcastable else dim
145-
for broadcastable, dim in zip(
146-
inp.type.broadcastable, shape[:batch_ndim]
153+
batch_shape = [
154+
1 if broadcastable else dim
155+
for broadcastable, dim in zip(
156+
squeezed_value.type.broadcastable[
157+
:squeezed_value_batch_ndim
158+
],
159+
tuple(squeezed_value.shape)[:squeezed_value_batch_ndim],
160+
)
161+
]
162+
squeezed_value = alloc(squeezed_value, *batch_shape, *core_shape)
163+
if squeezed_value.type.broadcastable == inp.type.broadcastable:
164+
# We can't change anything about this Alloc input
165+
new_inputs.append(inp)
166+
continue
167+
168+
# We can push batch dims of this Alloc input
169+
batch_shapes.append(
170+
tuple(
171+
1 if broadcastable else dim
172+
for broadcastable, dim in zip(
173+
inp.type.broadcastable, shape[:batch_ndim]
174+
)
147175
)
148176
)
149-
)
150-
new_inputs.append(squeezed_value)
151-
can_push_any_alloc = True
177+
new_inputs.append(squeezed_value)
178+
can_push_any_alloc = True
179+
continue
152180

153-
else:
154-
# Nothing to do with this input other than removing dummy batch dims
155-
new_inputs.append(_squeeze_left(inp, batch_ndim))
181+
# Nothing to do with this input other than removing dummy batch dims
182+
new_inputs.append(_squeeze_left(inp, batch_ndim))
156183

157184
if not can_push_any_alloc:
158185
return None
@@ -167,17 +194,15 @@ def local_blockwise_alloc(fgraph, node):
167194
missing_ndim = old_out_type.ndim - new_out_type.ndim
168195
batch_shape = ([1] * missing_ndim + list(new_outs[0].shape))[:batch_ndim]
169196
for i, batch_dims in enumerate(zip(*batch_shapes)): # Transpose shape tuples
197+
if old_out_type.broadcastable[i]:
198+
continue
170199
for batch_dim in batch_dims:
171200
if batch_dim == 1:
172201
continue
202+
batch_shape[i] = batch_dim
173203
if isinstance(batch_dim, Constant):
174204
# Give preference to Constants
175-
batch_shape[i] = batch_dim
176205
break
177-
elif old_out_type.broadcastable[i]:
178-
# Only use non Constant shapes if absolutely necessary
179-
# Otherwise, we use the shape of the non-alloc output
180-
batch_shape[i] = batch_dim
181206

182207
copy_stack_trace(node.outputs, new_outs)
183208
new_outs = [
@@ -190,3 +215,29 @@ def local_blockwise_alloc(fgraph, node):
190215
]
191216
copy_stack_trace(node.outputs, new_outs)
192217
return new_outs
218+
219+
220+
@register_canonicalize
221+
@register_specialize
222+
@node_rewriter([Blockwise])
223+
def local_blockwise_reshape(fgraph, node):
224+
"""Rewrite away square Blockwise reshapes.
225+
226+
Reshape is tricky to vectorize eagerly, because a graph like
227+
`x.reshape([x.shape[0] * x.shape[1], -1])` has many operations
228+
that must be vectorized before we arrive at the reshape operation.
229+
230+
For the square Reshape case, we must wait for all the intermediate
231+
operations to be lifted as Allocs.
232+
"""
233+
if not isinstance(node.op.core_op, Reshape):
234+
return None
235+
236+
x, output_shape = node.inputs
237+
batch_ndim = node.op.batch_ndim(node)
238+
if all(output_shape.type.broadcastable[:batch_ndim]):
239+
batched_shape = x.shape[:batch_ndim]
240+
core_reshape = _squeeze_left(output_shape, batch_ndim)
241+
new_out = x.reshape([*tuple(batched_shape), *tuple(core_reshape)])
242+
copy_stack_trace(node.outputs[0], new_out)
243+
return [new_out]

0 commit comments

Comments
 (0)