Skip to content

Commit d5a054d

Browse files
committed
Generalize lift of Subtensor over Elemwise
Split off Subtensor of Unbroadcast into its own rewrite
1 parent f1db1bd commit d5a054d

File tree

2 files changed

+153
-174
lines changed

2 files changed

+153
-174
lines changed

pytensor/tensor/rewriting/subtensor_lift.py

Lines changed: 68 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -108,73 +108,79 @@ def local_subtensor_of_dot(fgraph, node):
108108
return [r]
109109

110110

111-
# fast_compile to allow opt subtensor(cast{float32}(make_vector))
112-
@register_canonicalize("fast_compile")
111+
@register_canonicalize("shape_unsafe")
112+
@register_specialize("shape_unsafe")
113113
@node_rewriter([Subtensor])
114-
def local_subtensor_lift(fgraph, node):
114+
def local_subtensor_of_elemwise(fgraph, node):
115+
"""Lift a Subtensor through an Elemwise and its implicit broadcasting behavior.
116+
117+
exp(x)[:, 0] -> exp(x[:, 0])
118+
add(x, y)[0] -> add(x[0], y[0])
119+
add(x[None], y)[2] -> add(x, y[2])
115120
"""
116-
unary(x)[idx] -> unary(x[idx])#any broadcast pattern.
121+
elem, *idx = node.inputs
117122

118-
Handles the following unary ops:
119-
elemwise(x,...)[idx] -> elemwise(x[idx],...)
120-
when x,... are broadcasted scalar or not broadcasted at all
123+
if not (elem.owner and isinstance(elem.owner.op, Elemwise)):
124+
return None
121125

122-
"""
123-
if isinstance(node.op, Subtensor):
124-
u = node.inputs[0]
125-
if u.owner is None or len(fgraph.clients[u]) > 1:
126-
return False
127-
128-
if isinstance(u.owner.op, Elemwise) and len(u.owner.inputs) == 1:
129-
idx = node.inputs[1:]
130-
x_idx = node.op(u.owner.inputs[0], *idx)
131-
# Copy over previous output stacktrace
132-
copy_stack_trace(node.outputs, x_idx)
133-
ret = u.owner.op(x_idx)
134-
# Copy over previous output stacktrace
135-
# and stacktrace from previous unary operation
136-
copy_stack_trace([node.outputs[0], node.inputs[0]], ret)
137-
return [ret]
126+
if len(fgraph.clients[elem]) > 1:
127+
# Elemwise output is used beyond the Subtensor.
128+
# Get out to avoid repeated computations
129+
return None
138130

139-
if isinstance(u.owner.op, Elemwise):
140-
new_inputs = []
141-
if all(sum(i.type.broadcastable) == 0 for i in u.owner.inputs):
142-
# There is no broadcastable in the inputs
143-
idx = node.inputs[1:]
144-
new_inputs = [node.op(i, *idx) for i in u.owner.inputs]
145-
# Copy over previous output stacktrace
146-
copy_stack_trace(node.outputs[0], new_inputs)
147-
148-
ret = u.owner.op(*new_inputs)
149-
# Copy over previous output stacktrace
150-
# and stacktrace from previous unary operation
151-
copy_stack_trace([node.outputs[0], node.inputs[0]], ret)
152-
return [ret]
153-
elif all(sum(i.type.broadcastable) in [i.ndim, 0] for i in u.owner.inputs):
154-
# There is no broadcastable in the inputs or it is scalar
155-
idx = node.inputs[1:]
156-
new_inputs = []
157-
for i in u.owner.inputs:
158-
if sum(i.type.broadcastable) == 0:
159-
new_inputs.append(node.op(i, *idx))
160-
else:
161-
# If the subtensor remove some dims, we must
162-
# lower the number of dimensions of this scalar.
163-
if node.outputs[0].ndim == i.ndim:
164-
new_inputs.append(i)
165-
else:
166-
new_inputs.append(
167-
i.dimshuffle(["x"] * node.outputs[0].ndim)
168-
)
169-
170-
# Copy over previous output stacktrace
171-
copy_stack_trace(node.outputs[0], new_inputs)
172-
173-
ret = u.owner.op(*new_inputs)
174-
# Copy over previous output stacktrace
175-
# and stacktrace from previous unary operation
176-
copy_stack_trace([node.outputs[0], node.inputs[0]], ret)
177-
return [ret]
131+
idx_tuple = indices_from_subtensor(idx, node.op.idx_list)
132+
133+
elem_inputs = elem.owner.inputs
134+
elem_bcast = elem.type.broadcastable
135+
if all(inp.type.broadcastable == elem_bcast for inp in elem_inputs):
136+
# No need to worry about implicit broadcasting.
137+
indexed_inputs = [inp[idx_tuple] for inp in elem_inputs]
138+
139+
else:
140+
# The original indices may not make sense on some of the broadcasted dimensions
141+
new_idxs = [list(idx_tuple) for _ in elem_inputs]
142+
for dim, (dim_idx, dim_bcast_out, *dim_bcast_inputs) in enumerate(
143+
zip(
144+
idx_tuple,
145+
elem_bcast,
146+
*(inp.type.broadcastable for inp in elem_inputs),
147+
# Indices can be shorter than input ndims
148+
strict=False,
149+
)
150+
):
151+
if is_full_slice(dim_idx):
152+
# Full slice can be safely applied to all inputs
153+
continue
154+
155+
if all(dim_bcast_inp == elem_bcast for dim_bcast_inp in dim_bcast_inputs):
156+
# This dim is not broadcasted for any of the inputs, original index can be applied to all inputs
157+
continue
158+
159+
# Some dims are broadcasted, so we need to adapt their indices
160+
# Slice indexing keeps the dimension, so we use a full slice for broadcasted inputs
161+
# Integer indexing drops the dimension, so we index by zero for the broadcasted inputs
162+
safe_bcast_dim_idx = slice(None) if isinstance(dim_idx, slice) else 0
163+
for inp_idx, dim_bcast_inp in zip(new_idxs, dim_bcast_inputs, strict=True):
164+
if dim_bcast_inp:
165+
inp_idx[dim] = safe_bcast_dim_idx
166+
167+
indexed_inputs = [
168+
inp[tuple(new_idx)]
169+
for inp, new_idx in zip(elem_inputs, new_idxs, strict=True)
170+
]
171+
172+
[old_out] = node.outputs
173+
174+
# Copy stack trace to new inputs
175+
[copy_stack_trace(old_out, new_inp) for new_inp in indexed_inputs]
176+
177+
# Define elemwise operation on indexed inputs
178+
new_out = elem.owner.op(*indexed_inputs)
179+
180+
# Copy stack trace to new output
181+
copy_stack_trace([old_out, *node.inputs], new_out)
182+
183+
return [new_out]
178184

179185

180186
@register_canonicalize("shape_unsafe")

tests/tensor/rewriting/test_subtensor_lift.py

Lines changed: 85 additions & 112 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import numpy as np
22
import pytest
3-
import unittest_tools as utt
43

54
from pytensor import (
65
Mode,
@@ -25,13 +24,11 @@
2524
from pytensor.tensor import (
2625
add,
2726
exp,
28-
inplace,
2927
iscalar,
3028
iscalars,
3129
lscalar,
3230
lscalars,
3331
matrix,
34-
scalar,
3532
shape,
3633
slicetype,
3734
specify_shape,
@@ -43,6 +40,7 @@
4340
from pytensor.tensor.elemwise import DimShuffle, Elemwise
4441
from pytensor.tensor.rewriting.subtensor_lift import (
4542
local_subtensor_make_vector,
43+
local_subtensor_of_elemwise,
4644
local_subtensor_shape_constant,
4745
)
4846
from pytensor.tensor.shape import SpecifyShape, _shape
@@ -58,22 +56,8 @@
5856
NO_OPTIMIZATION_MODE = Mode(linker="py", optimizer=None)
5957

6058

61-
class TestLocalSubtensorLift:
62-
def test_basic(self):
63-
# basic test that the Op works
64-
x = matrix("x")
65-
f = function([x], exp(x)[0], mode=mode_opt)
66-
67-
# Check stacktrace was copied over correctly after opt was applied
68-
assert check_stack_trace(f, ops_to_check="all")
69-
70-
prog = f.maker.fgraph.toposort()
71-
assert isinstance(prog[0].op, Subtensor) # first subtensor
72-
assert prog[1].op == exp
73-
assert len(prog) == 2
74-
f([[0, 1], [2, 3]]) # let debugmode test something
75-
76-
def test_basic_1(self):
59+
class TestLocalSubtensorOfElemwise:
60+
def test_unary_multiple_clients(self):
7761
# as test0, but we reuse the output of the elemwise
7862
# So we should not lift the subtensor
7963
x = matrix("x")
@@ -87,85 +71,16 @@ def test_basic_1(self):
8771
assert isinstance(prog[1].op, Subtensor) # first subtensor
8872
assert isinstance(prog[2].op, DeepCopyOp)
8973
assert len(prog) == 3
90-
f([[0, 1], [2, 3]]) # let debugmode test something
91-
92-
def test_basic_2(self):
93-
# basic test that the optimization work with scalar broadcasted
94-
x = matrix("x")
95-
y = scalar("y")
96-
z = matrix("z")
97-
f = function([x, y, z], exp(x + y + z)[0], mode=mode_opt)
98-
99-
prog = f.maker.fgraph.toposort()
100-
assert isinstance(prog[0].op, Subtensor)
101-
assert isinstance(prog[1].op, DimShuffle)
102-
assert isinstance(prog[2].op, Subtensor)
103-
assert isinstance(prog[3].op.scalar_op, ps.Composite) # Composite{add,add}
104-
assert len(prog) == 4
105-
106-
# Check stacktrace was copied over correctly after opt was applied
107-
assert check_stack_trace(f, ops_to_check=[Subtensor])
108-
109-
# let debugmode test something
110-
f([[0, 1], [2, 3]], 4, [[4, 5], [6, 7]])
111-
112-
def test_basic_3(self):
113-
# as 1, but take a slice
114-
x = matrix("x")
115-
y = scalar("y")
116-
z = matrix("z")
117-
f = function([x, y, z], exp(x + y + z)[0:2], mode=mode_opt)
118-
119-
prog = f.maker.fgraph.toposort()
120-
assert isinstance(prog[0].op, Subtensor)
121-
assert isinstance(prog[1].op, DimShuffle)
122-
assert isinstance(prog[2].op, Subtensor)
123-
assert isinstance(prog[3].op.scalar_op, ps.Composite) # Composite{add,add}
124-
assert len(prog) == 4
125-
126-
# Check stacktrace was copied over correctly after opt was applied
127-
assert check_stack_trace(f, ops_to_check=[Subtensor])
128-
129-
# let debugmode test something
130-
f([[0, 1], [2, 3]], 4, [[4, 5], [6, 7]])
131-
132-
def test_basic_4(self):
133-
# basic test that the optimization does work with broadcasting
134-
# for unary elemwise.
135-
y = vector("y")
136-
f = function([y], exp(y.dimshuffle(0, "x"))[0], mode=mode_opt)
137-
138-
# Check stacktrace was copied over correctly after opt was applied
139-
assert check_stack_trace(f, ops_to_check="all")
140-
141-
prog = f.maker.fgraph.toposort()
142-
assert isinstance(prog[0].op, Subtensor)
143-
assert isinstance(prog[1].op, DimShuffle)
144-
assert prog[2].op == exp
145-
assert len(prog) == 3
146-
f([4, 5]) # let debugmode test something
147-
148-
@utt.assertFailure_fast
149-
def test_basic_5(self):
150-
# basic test that the optimization doesn't work with broadcasting
151-
# ... It *could* be extended to,
152-
# ... but right now it doesn't, so it shouldn't try.
153-
x = matrix("x")
154-
y = vector("y")
155-
f = function([x, y], exp(x + y)[0], mode=mode_opt)
15674

157-
# Opt doesn't apply, so no need for check_stack_trace
158-
# assert check_stack_trace(f, ops_to_check='all')
159-
160-
prog = f.maker.fgraph.toposort()
161-
assert isinstance(prog[0].op, DimShuffle)
162-
assert prog[1].op == add
163-
assert isinstance(prog[2].op, Subtensor) # first subtensor
164-
assert prog[3].op == inplace.exp_inplace
165-
assert len(prog) == 4
166-
f([[0, 1], [2, 3]], [4, 5]) # let debugmode test something
75+
x_test = [[0, 1], [2, 3]]
76+
res1, res2 = f(x_test)
77+
np.testing.assert_allclose(
78+
res1,
79+
np.exp(x_test)[0],
80+
)
81+
np.testing.assert_allclose(res2, np.exp(x_test))
16782

168-
def test_basic_6(self):
83+
def test_multinary_multiple_clients(self):
16984
# test that we don't lift when we reuse the output of the
17085
# elemwise for other computation.
17186
x = matrix("x")
@@ -181,26 +96,84 @@ def test_basic_6(self):
18196
# first subtensor
18297
assert isinstance(prog[2].op, Subtensor)
18398
assert len(prog) == 3
184-
f([[0, 1], [2, 3]], [4, 5]) # let debugmode test something
18599

186-
def test_basic_7(self):
187-
# basic test that the optimization works with a scalar as input,
188-
# and a scalar as output (no broadcasting of the scalar needed).
189-
# The optimization used to fail and display an ERROR message.
100+
x_test = np.array([[0, 1], [2, 3]]).astype(x.dtype)
101+
y_test = np.array([4, 5]).astype(y.dtype)
102+
res1, res2 = f(x_test, y_test)
103+
np.testing.assert_allclose(
104+
res1,
105+
np.exp(x_test + y_test)[0],
106+
)
107+
np.testing.assert_allclose(
108+
res2,
109+
np.exp(x_test + y_test) + x_test,
110+
)
111+
112+
@pytest.mark.parametrize(
113+
"original_fn, expected_fn",
114+
[
115+
# Unary integer indexing
116+
(lambda x, y: exp(x)[0], lambda x, y: exp(x[0])),
117+
# Unary integer with expand_dims
118+
(lambda x, y: exp(x[:, None])[0], lambda x, y: exp(x[0][None])),
119+
# Integer indexing on non-broadcastable dimension
120+
(lambda x, y: add(x, y)[0], lambda x, y: add(x[0], y[0])),
121+
# Slice indexing on non-broadcastable dimension
122+
(lambda x, y: add(x, y)[1:], lambda x, y: add(x[1:], y[1:])),
123+
# Integer indexing on broadcastable dimension
124+
(lambda x, y: add(x[None], y[None])[0], lambda x, y: add(x, y)),
125+
(lambda x, y: add(x[None], y[None])[0, 1], lambda x, y: add(x[1], y[1])),
126+
(
127+
lambda x, y: add(x[None, :], y[:, None])[2],
128+
lambda x, y: add(x, y[2][None]),
129+
),
130+
(
131+
lambda x, y: add(x[:, None], y[None, :])[:, 2],
132+
lambda x, y: add(x, y[2][None]),
133+
),
134+
# Slice indexing on broadcastable dimension
135+
(
136+
lambda x, y: add(x[None], y[None])[1:],
137+
lambda x, y: add(x[None][1:], y[None][1:]),
138+
),
139+
(
140+
lambda x, y: add(x[None, :], y[:, None])[1:],
141+
lambda x, y: add(x[None, :], y[1:][:, None]),
142+
),
143+
],
144+
)
145+
def test_local_subtensor_of_elemwise(self, original_fn, expected_fn):
146+
rng = np.random.default_rng(257)
147+
x = pt.matrix("x", shape=(5, 3))
148+
y = pt.matrix("y", shape=(5, 3))
149+
x_test = rng.normal(size=x.type.shape).astype(x.dtype)
150+
y_test = rng.normal(size=y.type.shape).astype(y.dtype)
151+
152+
out = original_fn(x, y)
153+
expected_opt_out = expected_fn(x, y)
154+
opt_out = rewrite_graph(out)
155+
assert equal_computations([opt_out], [expected_opt_out]), debugprint(
156+
[expected_opt_out, opt_out], print_type=True
157+
)
158+
eval_kwargs = dict(mode=NO_OPTIMIZATION_MODE, on_unused_input="ignore")
159+
np.testing.assert_allclose(
160+
opt_out.eval({x: x_test, y: y_test}, **eval_kwargs),
161+
out.eval({x: x_test, y: y_test}, **eval_kwargs),
162+
)
190163

191-
x = vector("x")
192-
y = scalar("y")
193-
f = function([x, y], exp(x + y)[0], mode=mode_opt)
164+
def test_local_subtensor_of_elemwise_multiple_clients(self):
165+
x = pt.matrix("x", shape=(5, 3))
166+
y = pt.matrix("y", shape=(5, 3))
167+
out1 = add(x, y)
168+
out2 = out1[0]
194169

195-
# Check stacktrace was copied over correctly after opt was applied
196-
assert check_stack_trace(f, ops_to_check=Subtensor)
170+
# Rewrite should fail when another node uses out1 directly (in this case it's an extra output)
171+
fgraph = FunctionGraph([x, y], [out1, out2], clone=False)
172+
assert local_subtensor_of_elemwise.transform(fgraph, out2.owner) is None
197173

198-
prog = f.maker.fgraph.toposort()
199-
assert isinstance(prog[0].op, Subtensor)
200-
# Composite{add,exp}
201-
assert isinstance(prog[1].op.scalar_op, ps.Composite)
202-
assert len(prog) == 2
203-
f([1, 2, 3], 4) # let debugmode test something
174+
# Otherwise it should work
175+
fgraph.remove_output(0)
176+
assert local_subtensor_of_elemwise.transform(fgraph, out2.owner) is not None
204177

205178

206179
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)