9
9
from pytensor .compile .mode import get_default_mode
10
10
from pytensor .configdefaults import config
11
11
from pytensor .gradient import grad , jacobian
12
- from pytensor .graph .basic import equal_computations
12
+ from pytensor .graph .basic import Constant , equal_computations
13
13
from pytensor .graph .fg import FunctionGraph
14
14
from pytensor .graph .replace import clone_replace
15
15
from pytensor .scan .op import Scan
@@ -1208,7 +1208,7 @@ def test_inplace3(self):
1208
1208
1209
1209
1210
1210
class TestSaveMem :
1211
- mode = get_default_mode ().including ("scan_save_mem" , "scan_save_mem" )
1211
+ mode = get_default_mode ().including ("scan_save_mem" )
1212
1212
1213
1213
def test_save_mem (self ):
1214
1214
rng = np .random .default_rng (utt .fetch_seed ())
@@ -1295,11 +1295,27 @@ def f_rnn(u_t):
1295
1295
[x1 [:2 ], x2 [4 ], x3 [idx ], x4 [:idx ], x5 [- 10 ], x6 [- jdx ], x7 [:- jdx ]],
1296
1296
updates = updates ,
1297
1297
allow_input_downcast = True ,
1298
- mode = self .mode ,
1298
+ mode = self .mode . excluding ( "scan_push_out_seq" ) ,
1299
1299
)
1300
+ # Check we actually have a Scan in the compiled function
1301
+ [scan_node ] = [
1302
+ node for node in f2 .maker .fgraph .toposort () if isinstance (node .op , Scan )
1303
+ ]
1304
+
1300
1305
# get random initial values
1301
1306
rng = np .random .default_rng (utt .fetch_seed ())
1302
- v_u = rng .uniform (- 5.0 , 5.0 , size = (20 ,))
1307
+ v_u = rng .uniform (- 5.0 , 5.0 , size = (20 ,)).astype (u .type .dtype )
1308
+
1309
+ # Check the number of steps is actually reduced from 20
1310
+ n_steps = scan_node .inputs [0 ]
1311
+ n_steps_fn = pytensor .function (
1312
+ [u , idx , jdx ], n_steps , accept_inplace = True , on_unused_input = "ignore"
1313
+ )
1314
+ assert n_steps_fn (u = v_u , idx = 3 , jdx = 15 ) == 11 # x5[const=-10] requires 11 steps
1315
+ assert n_steps_fn (u = v_u , idx = 3 , jdx = 3 ) == 18 # x6[jdx=-3] requires 18 steps
1316
+ assert n_steps_fn (u = v_u , idx = 16 , jdx = 15 ) == 17 # x3[idx=16] requires 17 steps
1317
+ assert n_steps_fn (u = v_u , idx = - 5 , jdx = 15 ) == 16 # x3[idx=-5] requires 16 steps
1318
+ assert n_steps_fn (u = v_u , idx = 19 , jdx = 15 ) == 20 # x3[idx=19] requires 20 steps
1303
1319
1304
1320
# compute the output in numpy
1305
1321
tx1 , tx2 , tx3 , tx4 , tx5 , tx6 , tx7 = f2 (v_u , 3 , 15 )
@@ -1312,6 +1328,49 @@ def f_rnn(u_t):
1312
1328
utt .assert_allclose (tx6 , v_u [- 15 ] + 6.0 )
1313
1329
utt .assert_allclose (tx7 , v_u [:- 15 ] + 7.0 )
1314
1330
1331
def test_save_mem_reduced_number_of_steps_constant(self):
    """Check that a constant ``n_steps`` is shrunk when only a prefix is used.

    The scan is built with ``n_steps=10`` but only ``xs[:5]`` is requested,
    so the compiled Scan node is expected to carry a constant ``n_steps``
    of 5 while still producing the correct first five values.
    """
    x0 = pt.scalar("x0")

    def step(xtm1):
        return xtm1 + 1

    xs, _ = scan(step, outputs_info=[x0], n_steps=10)

    fn = function([x0], xs[:5], mode=self.mode)

    # Unpacking fails unless exactly one Scan remains in the compiled graph
    scan_nodes = [
        apply for apply in fn.maker.fgraph.toposort() if isinstance(apply.op, Scan)
    ]
    (scan_node,) = scan_nodes

    # The first input of the Scan node is its number of steps
    n_steps = scan_node.inputs[0]
    assert isinstance(n_steps, Constant)
    assert n_steps.data == 5

    expected = np.arange(1, 11)[:5]
    np.testing.assert_allclose(fn(0), expected)
1347
+
1348
def test_save_mem_cannot_reduce_constant_number_of_steps(self):
    """Check that a constant ``n_steps`` is kept when the last step is needed.

    Two sequences are scanned for ``n_steps=10``. Only ``xs[:5]`` is used
    from the first one, but ``ys[-1]`` is requested from the second, so the
    compiled Scan node is expected to keep a constant ``n_steps`` of 10.
    """
    x0 = pt.scalar("x0")

    def step(xtm1, ytm1):
        return xtm1 + 1, ytm1 - 1

    (xs, ys), _ = scan(step, outputs_info=[x0, x0], n_steps=10)

    # Because of ys[-1] we need all the steps!
    requested = [xs[:5], ys[-1]]
    fn = function([x0], requested, mode=self.mode)

    # Unpacking fails unless exactly one Scan remains in the compiled graph
    scan_nodes = [
        apply for apply in fn.maker.fgraph.toposort() if isinstance(apply.op, Scan)
    ]
    (scan_node,) = scan_nodes

    # The first input of the Scan node is its number of steps
    n_steps = scan_node.inputs[0]
    assert isinstance(n_steps, Constant)
    assert n_steps.data == 10

    res_x, res_y = fn(0)
    steps = np.arange(1, 11)
    np.testing.assert_allclose(res_x, steps[:5])
    np.testing.assert_allclose(res_y, -steps[-1])
1373
+
1315
1374
def test_save_mem_store_steps (self ):
1316
1375
def f_rnn (u_t , x1_tm1 , x1_tm3 , x2_tm1 , x3tm2 , x3_tm1 , x4_tm1 ):
1317
1376
return (
0 commit comments