Skip to content

Commit d7edde2

Browse files
committed
Fix constant number of steps reduction in ScanSaveMem rewrite
isinstance(..., int) does not recognize numpy.integers Also remove maxsize logic
1 parent b27c59d commit d7edde2

File tree

3 files changed

+70
-24
lines changed

3 files changed

+70
-24
lines changed

pytensor/link/jax/dispatch/scan.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ def scan(*outer_inputs):
2929
# Extract JAX scan inputs
3030
outer_inputs = list(outer_inputs)
3131
n_steps = outer_inputs[0] # JAX `length`
32-
seqs = op.outer_seqs(outer_inputs) # JAX `xs`
32+
seqs = [seq[:n_steps] for seq in op.outer_seqs(outer_inputs)] # JAX `xs`
3333

3434
mit_sot_init = []
3535
for tap, seq in zip(

pytensor/scan/rewriting.py

Lines changed: 6 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
import copy
44
import dataclasses
55
from itertools import chain
6-
from sys import maxsize
76
from typing import cast
87

98
import numpy as np
@@ -1351,10 +1350,9 @@ def scan_save_mem(fgraph, node):
13511350
get_scalar_constant_value(cf_slice[0], raise_not_constant=False)
13521351
+ 1
13531352
)
1354-
if stop == maxsize or stop == get_scalar_constant_value(
1355-
length, raise_not_constant=False
1356-
):
1353+
if stop == get_scalar_constant_value(length, raise_not_constant=False):
13571354
stop = None
1355+
global_nsteps = None
13581356
else:
13591357
# there is a **gotcha** here ! Namely, scan returns an
13601358
# array that contains the initial state of the output
@@ -1366,21 +1364,13 @@ def scan_save_mem(fgraph, node):
13661364
# initial state)
13671365
stop = stop - init_l[i]
13681366

1369-
# 2.3.3 we might get away with less number of steps
1367+
# 2.3.3 we might get away with fewer steps
13701368
if stop is not None and global_nsteps is not None:
13711369
# yes if it is a tensor
13721370
if isinstance(stop, Variable):
13731371
global_nsteps["sym"] += [stop]
1374-
# not if it is maxsize
1375-
elif isinstance(stop, int) and stop == maxsize:
1376-
global_nsteps = None
1377-
# yes if it is a int k, 0 < k < maxsize
1378-
elif isinstance(stop, int) and global_nsteps["real"] < stop:
1379-
global_nsteps["real"] = stop
1380-
# yes if it is a int k, 0 < k < maxsize
1381-
elif isinstance(stop, int) and stop > 0:
1382-
pass
1383-
# not otherwise
1372+
elif isinstance(stop, int | np.integer):
1373+
global_nsteps["real"] = max(global_nsteps["real"], stop)
13841374
else:
13851375
global_nsteps = None
13861376

@@ -1703,10 +1693,7 @@ def scan_save_mem(fgraph, node):
17031693
- init_l[pos]
17041694
+ store_steps[pos]
17051695
)
1706-
if (
1707-
cnf_slice[0].stop is not None
1708-
and cnf_slice[0].stop != maxsize
1709-
):
1696+
if cnf_slice[0].stop is not None:
17101697
stop = (
17111698
cnf_slice[0].stop
17121699
- nw_steps

tests/scan/test_rewriting.py

Lines changed: 63 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from pytensor.compile.mode import get_default_mode
1010
from pytensor.configdefaults import config
1111
from pytensor.gradient import grad, jacobian
12-
from pytensor.graph.basic import equal_computations
12+
from pytensor.graph.basic import Constant, equal_computations
1313
from pytensor.graph.fg import FunctionGraph
1414
from pytensor.graph.replace import clone_replace
1515
from pytensor.scan.op import Scan
@@ -1208,7 +1208,7 @@ def test_inplace3(self):
12081208

12091209

12101210
class TestSaveMem:
1211-
mode = get_default_mode().including("scan_save_mem", "scan_save_mem")
1211+
mode = get_default_mode().including("scan_save_mem")
12121212

12131213
def test_save_mem(self):
12141214
rng = np.random.default_rng(utt.fetch_seed())
@@ -1295,11 +1295,27 @@ def f_rnn(u_t):
12951295
[x1[:2], x2[4], x3[idx], x4[:idx], x5[-10], x6[-jdx], x7[:-jdx]],
12961296
updates=updates,
12971297
allow_input_downcast=True,
1298-
mode=self.mode,
1298+
mode=self.mode.excluding("scan_push_out_seq"),
12991299
)
1300+
# Check we actually have a Scan in the compiled function
1301+
[scan_node] = [
1302+
node for node in f2.maker.fgraph.toposort() if isinstance(node.op, Scan)
1303+
]
1304+
13001305
# get random initial values
13011306
rng = np.random.default_rng(utt.fetch_seed())
1302-
v_u = rng.uniform(-5.0, 5.0, size=(20,))
1307+
v_u = rng.uniform(-5.0, 5.0, size=(20,)).astype(u.type.dtype)
1308+
1309+
# Check the number of steps is actually reduced from 20
1310+
n_steps = scan_node.inputs[0]
1311+
n_steps_fn = pytensor.function(
1312+
[u, idx, jdx], n_steps, accept_inplace=True, on_unused_input="ignore"
1313+
)
1314+
assert n_steps_fn(u=v_u, idx=3, jdx=15) == 11 # x5[const=-10] requires 11 steps
1315+
assert n_steps_fn(u=v_u, idx=3, jdx=3) == 18 # x6[jdx=-3] requires 18 steps
1316+
assert n_steps_fn(u=v_u, idx=16, jdx=15) == 17 # x3[idx=16] requires 17 steps
1317+
assert n_steps_fn(u=v_u, idx=-5, jdx=15) == 16 # x3[idx=-5] requires 16 steps
1318+
assert n_steps_fn(u=v_u, idx=19, jdx=15) == 20 # x3[idx=19] requires 20 steps
13031319

13041320
# compute the output in numpy
13051321
tx1, tx2, tx3, tx4, tx5, tx6, tx7 = f2(v_u, 3, 15)
@@ -1312,6 +1328,49 @@ def f_rnn(u_t):
13121328
utt.assert_allclose(tx6, v_u[-15] + 6.0)
13131329
utt.assert_allclose(tx7, v_u[:-15] + 7.0)
13141330

1331+
def test_save_mem_reduced_number_of_steps_constant(self):
1332+
x0 = pt.scalar("x0")
1333+
xs, _ = scan(
1334+
lambda xtm1: xtm1 + 1,
1335+
outputs_info=[x0],
1336+
n_steps=10,
1337+
)
1338+
1339+
fn = function([x0], xs[:5], mode=self.mode)
1340+
[scan_node] = [
1341+
node for node in fn.maker.fgraph.toposort() if isinstance(node.op, Scan)
1342+
]
1343+
n_steps = scan_node.inputs[0]
1344+
assert isinstance(n_steps, Constant) and n_steps.data == 5
1345+
1346+
np.testing.assert_allclose(fn(0), np.arange(1, 11)[:5])
1347+
1348+
def test_save_mem_cannot_reduce_constant_number_of_steps(self):
1349+
x0 = pt.scalar("x0")
1350+
[xs, ys], _ = scan(
1351+
lambda xtm1, ytm1: (xtm1 + 1, ytm1 - 1),
1352+
outputs_info=[x0, x0],
1353+
n_steps=10,
1354+
)
1355+
1356+
# Because of ys[-1] we need all the steps!
1357+
fn = function([x0], [xs[:5], ys[-1]], mode=self.mode)
1358+
[scan_node] = [
1359+
node for node in fn.maker.fgraph.toposort() if isinstance(node.op, Scan)
1360+
]
1361+
n_steps = scan_node.inputs[0]
1362+
assert isinstance(n_steps, Constant) and n_steps.data == 10
1363+
1364+
res_x, res_y = fn(0)
1365+
np.testing.assert_allclose(
1366+
res_x,
1367+
np.arange(1, 11)[:5],
1368+
)
1369+
np.testing.assert_allclose(
1370+
res_y,
1371+
-np.arange(1, 11)[-1],
1372+
)
1373+
13151374
def test_save_mem_store_steps(self):
13161375
def f_rnn(u_t, x1_tm1, x1_tm3, x2_tm1, x3tm2, x3_tm1, x4_tm1):
13171376
return (

0 commit comments

Comments
 (0)