
Commit 43cad30
Lift Subtensor over CAReduce
1 parent: d5a054d

File tree

2 files changed: +94, −1 lines


pytensor/tensor/rewriting/subtensor_lift.py

Lines changed: 59 additions & 1 deletion
@@ -5,6 +5,7 @@
 from pytensor import Variable
 from pytensor.graph import Constant, node_rewriter
 from pytensor.graph.rewriting.basic import copy_stack_trace
+from pytensor.npy_2_compat import normalize_axis_tuple
 from pytensor.scalar import basic as ps
 from pytensor.tensor.basic import (
     Alloc,
@@ -15,7 +16,7 @@
     get_underlying_scalar_constant_value,
     register_infer_shape,
 )
-from pytensor.tensor.elemwise import DimShuffle, Elemwise
+from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
 from pytensor.tensor.exceptions import NotScalarConstantError
 from pytensor.tensor.extra_ops import squeeze
 from pytensor.tensor.math import Dot, ceil_intdiv, dot
@@ -183,6 +184,63 @@ def local_subtensor_of_elemwise(fgraph, node):
     return [new_out]
 
 
+@register_canonicalize
+@register_specialize
+@node_rewriter([Subtensor])
+def local_subtensor_of_reduce(fgraph, node):
+    """Lift a Subtensor through a CAReduce Op.
+
+    For now the rewrite is restricted to a single axis of reduction, for simplicity.
+
+    sum(x, axis=1)[0] -> sum(x[0], axis=0)
+    sum(x, axis=1)[1:] -> sum(x[1:], axis=1)
+    sum(x, axis=0)[0] -> sum(x[:, 0], axis=0)
+    sum(x, axis=0)[1:] -> sum(x[:, 1:], axis=0)
+
+    """
+    red, *idx = node.inputs
+
+    if not (red.owner and isinstance(red.owner.op, CAReduce)):
+        return None
+
+    if len(fgraph.clients[red]) > 1:
+        # Don't apply the rewrite if another node requires the full reduction
+        return None
+
+    [x] = red.owner.inputs
+    axis = red.owner.op.axis
+
+    if axis is None:
+        axis = tuple(range(x.type.ndim))
+
+    # TODO: Allow reduction across multiple axes
+    if len(axis) != 1:
+        return None
+
+    [axis] = normalize_axis_tuple(axis, x.ndim)
+    idx_tuple = indices_from_subtensor(idx, node.op.idx_list)
+
+    # Index the input of the reduction.
+    new_idxs = list(idx_tuple)
+    if axis < len(idx_tuple):
+        # When there are indices at or beyond the axis of reduction, we need to shift them with full slices.
+        new_idxs.insert(axis, slice(None))
+    x_sub = x[tuple(new_idxs)]
+
+    [old_out] = node.outputs
+    copy_stack_trace(old_out, x_sub)
+
+    # Adjust the axis of reduction when indexing drops dimensions (integer indexing as opposed to slice indexing)
+    axis -= len(
+        [idx_item for idx_item in idx_tuple[:axis] if not isinstance(idx_item, slice)]
+    )
+
+    # Apply the reduction to the indexed input
+    out = type(red.owner.op)(axis=axis)(x_sub)
+    copy_stack_trace(old_out, out)
+    return [out]
+
+
 @register_canonicalize("shape_unsafe")
 @register_specialize("shape_unsafe")
 @node_rewriter([Subtensor])
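
To see the new rewrite end to end, here is a minimal sketch (not part of the commit) that builds sum(x, axis=1)[0] and runs the default rewrite pipeline. It assumes rewrite_graph is importable from pytensor.graph.rewriting.utils and that pytensor.dprint is available, mirroring the test setup below.

import pytensor
import pytensor.tensor as pt
from pytensor.graph.rewriting.utils import rewrite_graph

x = pt.tensor("x", shape=(5, 3, 2))

# Reduce over axis 1, then take row 0. Since local_subtensor_of_reduce is
# registered under canonicalize and specialize, the default pipeline should
# lift the Subtensor so that only x[0] is ever reduced.
before = pt.sum(x, axis=1)[0]
after = rewrite_graph(before)
pytensor.dprint(after)  # expected to be equivalent to pt.sum(x[0], axis=0)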

tests/tensor/rewriting/test_subtensor_lift.py

Lines changed: 35 additions & 0 deletions
@@ -38,6 +38,7 @@
 )
 from pytensor.tensor.basic import MakeVector, expand_dims, make_vector
 from pytensor.tensor.elemwise import DimShuffle, Elemwise
+from pytensor.tensor.math import sum as pt_sum
 from pytensor.tensor.rewriting.subtensor_lift import (
     local_subtensor_make_vector,
     local_subtensor_of_elemwise,
@@ -176,6 +177,40 @@ def test_local_subtensor_of_elemwise_multiple_clients(self):
     assert local_subtensor_of_elemwise.transform(fgraph, out2.owner) is not None
 
 
+@pytest.mark.parametrize(
+    "original_fn, expected_fn",
+    [
+        # Indexing before the axis of reduction
+        (lambda x: pt_sum(x, axis=2)[0], lambda x: pt_sum(x[0], axis=1)),
+        (lambda x: pt_sum(x, axis=2)[0, 1], lambda x: pt_sum(x[0, 1], axis=None)),
+        (lambda x: pt_sum(x, axis=2)[1:], lambda x: pt_sum(x[1:], axis=2)),
+        # Indexing "at" the axis of reduction
+        (lambda x: pt_sum(x, axis=0)[2], lambda x: pt_sum(x[:, 2], axis=0)),
+        (lambda x: pt_sum(x, axis=0)[:-2], lambda x: pt_sum(x[:, :-2], axis=0)),
+        # Indexing after the axis of reduction
+        (lambda x: pt_sum(x, axis=0)[:, 1:], lambda x: pt_sum(x[:, :, 1:], axis=0)),
+        # Indexing before and after the axis of reduction
+        (lambda x: pt_sum(x, axis=1)[-2, 1:], lambda x: pt_sum(x[-2, :, 1:], axis=0)),
+        (lambda x: pt_sum(x, axis=1)[1:, -2], lambda x: pt_sum(x[1:, :, -2], axis=1)),
+    ],
+)
+def test_local_subtensor_of_reduce(original_fn, expected_fn):
+    rng = np.random.default_rng(245)
+    x = pt.tensor("x", shape=(5, 3, 2))
+    x_test = rng.normal(size=x.type.shape).astype(x.dtype)
+
+    out = original_fn(x)
+    expected_opt_out = expected_fn(x)
+    opt_out = rewrite_graph(out)
+    assert equal_computations([opt_out], [expected_opt_out]), debugprint(
+        [expected_opt_out, opt_out], print_type=True
+    )
+    np.testing.assert_allclose(
+        opt_out.eval({x: x_test}, mode=NO_OPTIMIZATION_MODE),
+        out.eval({x: x_test}, mode=NO_OPTIMIZATION_MODE),
+    )
+
+
 @pytest.mark.parametrize(
     "original_fn, expected_fn",
     [
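
The axis bookkeeping exercised by these cases can be sanity-checked with plain NumPy, independent of PyTensor. This illustrative check (not part of the commit) verifies the identity behind the "before and after" case, where the integer index in front of the reduced axis drops a dimension and shifts the reduction from axis 1 to axis 0.

import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=(5, 3, 2))

# sum(x, axis=1)[-2, 1:] should equal sum(x[-2, :, 1:], axis=0): the leading
# integer index removes a dimension, so the reduced axis moves from 1 to 0.
np.testing.assert_allclose(
    x.sum(axis=1)[-2, 1:],
    x[-2, :, 1:].sum(axis=0),
)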
