
Commit fe20a66

Get rid of Reshape again, now that we Vectorize join
1 parent 99e7eab

3 files changed: 138 additions, 31 deletions

pytensor/tensor/basic.py

Lines changed: 8 additions & 7 deletions
@@ -1681,21 +1681,22 @@ def do_constant_folding(self, fgraph, node):
             return False

         for client, idx in clients:
-            if isinstance(client.op, Output):
+            client_op = client.op
+            if isinstance(client_op, Output):
                 # If the output is a constant, it will have to be deepcopied
                 # each time the function is called. So we do not fold.
                 return False
-            # Allow alloc to be lifted out of Elemwise before constant folding it
-            elif isinstance(client.op, Elemwise):
-                return None
+            # Op's through which Alloc can be lifted
+            elif isinstance(client_op, Elemwise | DimShuffle | Alloc | Join):
+                return False
             # Same for Blockwise, unless it has no batch_dims
-            elif isinstance(client.op, Blockwise) and client.op.batch_ndim(client):
-                return None
+            elif isinstance(client_op, Blockwise) and client.op.batch_ndim(client):
+                return False
             elif (
                 # The following ops work inplace of their input id 0.
                 idx == 0
                 and isinstance(
-                    client.op,
+                    client_op,
                     pytensor.tensor.subtensor.IncSubtensor
                     | pytensor.tensor.subtensor.AdvancedIncSubtensor1
                     | pytensor.tensor.subtensor.AdvancedIncSubtensor
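
For context on the `return False` cases: returning False tells PyTensor not to constant-fold an Alloc whose clients are all ops through which the Alloc can later be lifted. A minimal sketch of the intent (my example, not part of the commit; the exact rewritten graph may differ across versions):

```python
import pytensor
import pytensor.tensor as pt

x = pt.vector("x")
# Alloc with only constant inputs; its sole client is an Elemwise (the add),
# one of the ops through which Alloc can now be lifted instead of folded.
zeros = pt.alloc(0.0, 1000)
out = x + zeros

fn = pytensor.function([x], out)
# Folding would bake a 1000-element constant array into the graph; skipping
# the fold lets later rewrites lift or remove the Alloc instead.
pytensor.dprint(fn)
```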

pytensor/tensor/rewriting/basic.py

Lines changed: 122 additions & 0 deletions
@@ -52,6 +52,7 @@
     TensorFromScalar,
     alloc,
     as_tensor_variable,
+    atleast_Nd,
     cast,
     extract_constant,
     fill,
@@ -1219,3 +1220,124 @@ def local_merge_alloc(fgraph, node):


 register_canonicalize(RemovalNodeRewriter(tensor_copy), name="remove_tensor_copy")
+
+
+@register_specialize
+@node_rewriter([DimShuffle])
+def local_dimshuffle_alloc(fgraph, node):
+    """
+    Lift DimShuffle through Alloc
+
+    dimshuffle{x, 0, 1}(alloc([3 4], 3, 2) => alloc([3 4], 1, 3, 2)
+    """
+    alloc_out = node.inputs[0]
+    alloc_node = alloc_out.owner
+    if not (alloc_node and isinstance(alloc_node.op, Alloc)):
+        return
+
+    ds_op = node.op
+    value, *alloc_shape = alloc_node.inputs
+
+    # Add implicit dimensions of value
+    value = atleast_Nd(value, n=len(alloc_shape))
+
+    # Dimshuffle value and alloc_shape
+    ds_value = value.dimshuffle(ds_op.new_order)
+    ds_alloc_shape = [alloc_shape[i] for i in ds_op.shuffle]
+    for dim in ds_op.augment:
+        ds_alloc_shape.insert(dim, 1)
+
+    return [alloc(ds_value, *ds_alloc_shape)]
+
+
+@register_specialize("shape_unsafe")
+@node_rewriter([Join])
+def local_join_of_alloc(fgraph, node):
+    """Rewrite a Join of Alloc nodes to an Alloc of the Join nodes."""
+    axis, *tensors = node.inputs
+
+    if len(tensors) < 2:
+        # Let other rewrite handle the useless Join
+        return
+
+    if not isinstance(axis, Constant):
+        return
+
+    core_tensors = []
+    alloc_shapes = []
+    for tensor in tensors:
+        if tensor.owner is None:
+            print(" > failed no owner")
+            return
+
+        # tensor = expand_dims_to_alloc(tensor)
+        if not isinstance(tensor.owner.op, Alloc):
+            return
+
+        value, *shape = tensor.owner.inputs
+        # Introduce explicit batch dims
+        value = atleast_Nd(value, n=len(shape))
+        core_tensors.append(value)
+        alloc_shapes.append(shape)
+
+    # Find which allocated dimensions can be lifted
+    # Axis can never be lifted
+    # Non-axis allocated dimensions can be lifted if they are all broadcastable
+    [out] = node.outputs
+    axis = axis.data
+
+    broadcasted_dims = list(
+        zip(
+            *(
+                [
+                    bef and not aft
+                    for bef, aft in zip(
+                        core_tensor.type.broadcastable,
+                        tensor.type.broadcastable,
+                        strict=True,
+                    )
+                ]
+                for core_tensor, tensor in zip(core_tensors, tensors, strict=True)
+            )
+        )
+    )
+
+    lifteable_alloc_dims = {
+        dim
+        for dim in range(out.type.ndim)
+        if dim != axis and all(broadcasted_dims[dim])
+    }
+
+    if not lifteable_alloc_dims:
+        return
+
+    # Lift the allocated dimensions
+    new_tensors = []
+    for core_tensor, alloc_shape in zip(core_tensors, alloc_shapes):
+        pre_join_shape = [
+            1 if i in lifteable_alloc_dims else alloc_dim
+            for i, alloc_dim in enumerate(alloc_shape)
+        ]
+        new_tensor = alloc(core_tensor, *pre_join_shape)
+        copy_stack_trace(tensor, new_tensor)
+        new_tensors.append(new_tensor)
+
+    new_join = node.op(axis, *new_tensors)
+    copy_stack_trace(node.outputs[0], new_join)
+
+    # Reintroduce the lifted dims
+    post_join_shape = []
+    for i, alloc_dims in enumerate(zip(*alloc_shapes)):
+        if i == axis:
+            # The alloc dim along the axis is the sum of all the pre-join alloc dims
+            post_join_shape.append(add(*alloc_dims))
+        else:
+            # Otherwise the shapes should all match. We prioritize constants if any
+            for best_alloc_dim in alloc_dims:
+                if isinstance(best_alloc_dim, Constant):
+                    break
+            post_join_shape.append(best_alloc_dim)
+
+    new_out = alloc(new_join, *post_join_shape)
+    copy_stack_trace(node.outputs[0], new_out)
+    return [new_out]
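
As a hedged illustration of what the two new rewrites aim to do (my sketch, not part of the commit; the exact output depends on which other specializations fire):

```python
import pytensor
import pytensor.tensor as pt
from pytensor.graph.rewriting.utils import rewrite_graph

v = pt.vector("v")
w = pt.vector("w")
n = pt.iscalar("n")

# local_dimshuffle_alloc: an expand_dims on top of an Alloc is absorbed into it,
# e.g. alloc(v, n, k).dimshuffle("x", 0, 1) -> alloc(v, 1, n, k)
expanded = pt.alloc(v, n, v.shape[0]).dimshuffle("x", 0, 1)
pytensor.dprint(rewrite_graph(expanded, include=("specialize",)))

# local_join_of_alloc: a Join of Allocs that only broadcast along a non-join
# axis becomes an Alloc of a Join of the core values, roughly
# alloc(join(1, v[None], w[None]), n, v.shape[0] + w.shape[0])
joined = pt.join(1, pt.alloc(v, n, v.shape[0]), pt.alloc(w, n, w.shape[0]))
pytensor.dprint(rewrite_graph(joined, include=("specialize",)))
```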

pytensor/tensor/rewriting/blockwise.py

Lines changed: 8 additions & 24 deletions
@@ -1,18 +1,15 @@
-from pytensor import Variable
 from pytensor.compile.mode import optdb
 from pytensor.graph import Constant, node_rewriter
 from pytensor.graph.replace import vectorize_node
 from pytensor.graph.rewriting.basic import copy_stack_trace, out2in
 from pytensor.tensor.basic import Alloc, ARange, alloc, shape_padleft
 from pytensor.tensor.blockwise import Blockwise
-from pytensor.tensor.elemwise import DimShuffle
 from pytensor.tensor.math import Dot
 from pytensor.tensor.rewriting.basic import (
     register_canonicalize,
     register_specialize,
     register_stabilize,
 )
-from pytensor.tensor.rewriting.uncanonicalize import local_dimshuffle_alloc
 from pytensor.tensor.shape import Reshape
 from pytensor.tensor.subtensor import AdvancedIncSubtensor, AdvancedSubtensor, Subtensor

@@ -71,7 +68,13 @@ def local_useless_unbatched_blockwise(fgraph, node):
 def local_eager_useless_unbatched_blockwise(fgraph, node):
     if isinstance(
         node.op.core_op,
-        Dot | Alloc | ARange | Subtensor | AdvancedSubtensor | AdvancedIncSubtensor,
+        Dot
+        | Alloc
+        | ARange
+        | Subtensor
+        | AdvancedSubtensor
+        | AdvancedIncSubtensor
+        | Reshape,
     ):
         # Many Dot-related rewrites (eg, all of BlasOpt) happen before specialize
         # These other Ops can't always be trivially vectorized at runtime,
@@ -90,18 +93,6 @@ def _squeeze_left(x, stop_at_dim: int | None = None):
     return x.squeeze(axis=tuple(range(squeeze_ndim)))


-def alloc_or_expand_dims_of_alloc(var: Variable) -> bool:
-    return var.owner and (
-        isinstance(var.owner.op, Alloc)
-        or (
-            isinstance(var.owner.op, DimShuffle)
-            and var.owner.inputs[0].owner
-            and isinstance(var.owner.inputs[0].owner.op, Alloc)
-        )
-    )
-
-
-@register_canonicalize("shape_unsafe")
 @register_specialize("shape_unsafe")
 @node_rewriter([Blockwise])
 def local_blockwise_alloc(fgraph, node):
@@ -119,20 +110,14 @@ def local_blockwise_alloc(fgraph, node):
     if not batch_ndim:
         return None

-    if not any(alloc_or_expand_dims_of_alloc(var) for var in node.inputs):
+    if not any(var.owner and isinstance(var.owner.op, Alloc) for var in node.inputs):
         return None

     new_inputs = []
     batch_shapes = []
     can_push_any_alloc = False
     for inp, inp_sig in zip(node.inputs, op.inputs_sig):
         if not all(inp.type.broadcastable[:batch_ndim]):
-            if inp.owner and isinstance(inp.owner.op, DimShuffle):
-                # Convert DimShuffle of Alloc to Alloc
-                new_inp = local_dimshuffle_alloc.transform(None, inp.owner)
-                if new_inp:
-                    [inp] = new_inp
-
             if inp.owner and isinstance(inp.owner.op, Alloc):
                 # Push batch dims from Alloc
                 value, *shape = inp.owner.inputs
@@ -217,7 +202,6 @@ def local_blockwise_alloc(fgraph, node):
     return new_outs


-@register_canonicalize
 @register_specialize
 @node_rewriter([Blockwise])
 def local_blockwise_reshape(fgraph, node):
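
For reference, a hedged sketch of why Reshape joins the core ops that are eagerly un-blockwised (my example, assuming current vectorize behavior, not verified against this commit): vectorizing a graph that reshapes can produce a Blockwise-wrapped Reshape, which the rewrites here turn back into a plain reshape on the batched input.

```python
import pytensor
import pytensor.tensor as pt
from pytensor.graph.replace import vectorize_graph

x = pt.vector("x", shape=(6,))
core_out = x.reshape((2, 3))

# Swap the core input for a batched one; the reshape may get wrapped in Blockwise
batch_x = pt.matrix("batch_x", shape=(None, 6))
batch_out = vectorize_graph(core_out, replace={x: batch_x})

fn = pytensor.function([batch_x], batch_out)
pytensor.dprint(fn)  # after rewrites, no Blockwise of Reshape should remain
```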

0 commit comments
