@@ -1,6 +1,5 @@
 import numpy as np
 import pytest
-import unittest_tools as utt

 from pytensor import (
     Mode,
@@ -25,13 +24,11 @@
 from pytensor.tensor import (
     add,
     exp,
-    inplace,
     iscalar,
     iscalars,
     lscalar,
     lscalars,
     matrix,
-    scalar,
     shape,
     slicetype,
     specify_shape,
@@ -43,6 +40,7 @@
 from pytensor.tensor.elemwise import DimShuffle, Elemwise
 from pytensor.tensor.rewriting.subtensor_lift import (
     local_subtensor_make_vector,
+    local_subtensor_of_elemwise,
     local_subtensor_shape_constant,
 )
 from pytensor.tensor.shape import SpecifyShape, _shape
@@ -58,22 +56,8 @@
 NO_OPTIMIZATION_MODE = Mode(linker="py", optimizer=None)


-class TestLocalSubtensorLift:
-    def test_basic(self):
-        # basic test that the Op works
-        x = matrix("x")
-        f = function([x], exp(x)[0], mode=mode_opt)
-
-        # Check stacktrace was copied over correctly after opt was applied
-        assert check_stack_trace(f, ops_to_check="all")
-
-        prog = f.maker.fgraph.toposort()
-        assert isinstance(prog[0].op, Subtensor)  # first subtensor
-        assert prog[1].op == exp
-        assert len(prog) == 2
-        f([[0, 1], [2, 3]])  # let debugmode test something
-
-    def test_basic_1(self):
+class TestLocalSubtensorOfElemwise:
+    def test_unary_multiple_clients(self):
         # as test0, but we reuse the output of the elemwise
         # So we should not lift the subtensor
         x = matrix("x")
@@ -87,85 +71,16 @@ def test_basic_1(self):
         assert isinstance(prog[1].op, Subtensor)  # first subtensor
         assert isinstance(prog[2].op, DeepCopyOp)
         assert len(prog) == 3
-        f([[0, 1], [2, 3]])  # let debugmode test something
-
-    def test_basic_2(self):
-        # basic test that the optimization work with scalar broadcasted
-        x = matrix("x")
-        y = scalar("y")
-        z = matrix("z")
-        f = function([x, y, z], exp(x + y + z)[0], mode=mode_opt)
-
-        prog = f.maker.fgraph.toposort()
-        assert isinstance(prog[0].op, Subtensor)
-        assert isinstance(prog[1].op, DimShuffle)
-        assert isinstance(prog[2].op, Subtensor)
-        assert isinstance(prog[3].op.scalar_op, ps.Composite)  # Composite{add,add}
-        assert len(prog) == 4
-
-        # Check stacktrace was copied over correctly after opt was applied
-        assert check_stack_trace(f, ops_to_check=[Subtensor])
-
-        # let debugmode test something
-        f([[0, 1], [2, 3]], 4, [[4, 5], [6, 7]])
-
-    def test_basic_3(self):
-        # as 1, but take a slice
-        x = matrix("x")
-        y = scalar("y")
-        z = matrix("z")
-        f = function([x, y, z], exp(x + y + z)[0:2], mode=mode_opt)
-
-        prog = f.maker.fgraph.toposort()
-        assert isinstance(prog[0].op, Subtensor)
-        assert isinstance(prog[1].op, DimShuffle)
-        assert isinstance(prog[2].op, Subtensor)
-        assert isinstance(prog[3].op.scalar_op, ps.Composite)  # Composite{add,add}
-        assert len(prog) == 4
-
-        # Check stacktrace was copied over correctly after opt was applied
-        assert check_stack_trace(f, ops_to_check=[Subtensor])
-
-        # let debugmode test something
-        f([[0, 1], [2, 3]], 4, [[4, 5], [6, 7]])
-
-    def test_basic_4(self):
-        # basic test that the optimization does work with broadcasting
-        # for unary elemwise.
-        y = vector("y")
-        f = function([y], exp(y.dimshuffle(0, "x"))[0], mode=mode_opt)
-
-        # Check stacktrace was copied over correctly after opt was applied
-        assert check_stack_trace(f, ops_to_check="all")
-
-        prog = f.maker.fgraph.toposort()
-        assert isinstance(prog[0].op, Subtensor)
-        assert isinstance(prog[1].op, DimShuffle)
-        assert prog[2].op == exp
-        assert len(prog) == 3
-        f([4, 5])  # let debugmode test something
-
-    @utt.assertFailure_fast
-    def test_basic_5(self):
-        # basic test that the optimization doesn't work with broadcasting
-        # ... It *could* be extended to,
-        # ... but right now it doesn't, so it shouldn't try.
-        x = matrix("x")
-        y = vector("y")
-        f = function([x, y], exp(x + y)[0], mode=mode_opt)

-        # Opt doesn't apply, so no need for check_stack_trace
-        # assert check_stack_trace(f, ops_to_check='all')
-
-        prog = f.maker.fgraph.toposort()
-        assert isinstance(prog[0].op, DimShuffle)
-        assert prog[1].op == add
-        assert isinstance(prog[2].op, Subtensor)  # first subtensor
-        assert prog[3].op == inplace.exp_inplace
-        assert len(prog) == 4
-        f([[0, 1], [2, 3]], [4, 5])  # let debugmode test something
+        x_test = [[0, 1], [2, 3]]
+        res1, res2 = f(x_test)
+        np.testing.assert_allclose(
+            res1,
+            np.exp(x_test)[0],
+        )
+        np.testing.assert_allclose(res2, np.exp(x_test))

-    def test_basic_6(self):
+    def test_multinary_multiple_clients(self):
         # test that we don't lift when we reuse the output of the
         # elemwise for other computation.
         x = matrix("x")
@@ -181,26 +96,84 @@ def test_basic_6(self):
         # first subtensor
         assert isinstance(prog[2].op, Subtensor)
         assert len(prog) == 3
-        f([[0, 1], [2, 3]], [4, 5])  # let debugmode test something

-    def test_basic_7(self):
-        # basic test that the optimization works with a scalar as input,
-        # and a scalar as output (no broadcasting of the scalar needed).
-        # The optimization used to fail and display an ERROR message.
+        x_test = np.array([[0, 1], [2, 3]]).astype(x.dtype)
+        y_test = np.array([4, 5]).astype(y.dtype)
+        res1, res2 = f(x_test, y_test)
+        np.testing.assert_allclose(
+            res1,
+            np.exp(x_test + y_test)[0],
+        )
+        np.testing.assert_allclose(
+            res2,
+            np.exp(x_test + y_test) + x_test,
+        )
+
+    @pytest.mark.parametrize(
+        "original_fn, expected_fn",
+        [
+            # Unary integer indexing
+            (lambda x, y: exp(x)[0], lambda x, y: exp(x[0])),
+            # Unary integer with expand_dims
+            (lambda x, y: exp(x[:, None])[0], lambda x, y: exp(x[0][None])),
+            # Integer indexing on non-broadcastable dimension
+            (lambda x, y: add(x, y)[0], lambda x, y: add(x[0], y[0])),
+            # Slice indexing on non-broadcastable dimension
+            (lambda x, y: add(x, y)[1:], lambda x, y: add(x[1:], y[1:])),
+            # Integer indexing on broadcastable dimension
+            (lambda x, y: add(x[None], y[None])[0], lambda x, y: add(x, y)),
+            (lambda x, y: add(x[None], y[None])[0, 1], lambda x, y: add(x[1], y[1])),
+            (
+                lambda x, y: add(x[None, :], y[:, None])[2],
+                lambda x, y: add(x, y[2][None]),
+            ),
+            (
+                lambda x, y: add(x[:, None], y[None, :])[:, 2],
+                lambda x, y: add(x, y[2][None]),
+            ),
+            # Slice indexing on broadcastable dimension
+            (
+                lambda x, y: add(x[None], y[None])[1:],
+                lambda x, y: add(x[None][1:], y[None][1:]),
+            ),
+            (
+                lambda x, y: add(x[None, :], y[:, None])[1:],
+                lambda x, y: add(x[None, :], y[1:][:, None]),
+            ),
+        ],
+    )
+    def test_local_subtensor_of_elemwise(self, original_fn, expected_fn):
+        rng = np.random.default_rng(257)
+        x = pt.matrix("x", shape=(5, 3))
+        y = pt.matrix("y", shape=(5, 3))
+        x_test = rng.normal(size=x.type.shape).astype(x.dtype)
+        y_test = rng.normal(size=y.type.shape).astype(y.dtype)
+
+        out = original_fn(x, y)
+        expected_opt_out = expected_fn(x, y)
+        opt_out = rewrite_graph(out)
+        assert equal_computations([opt_out], [expected_opt_out]), debugprint(
+            [expected_opt_out, opt_out], print_type=True
+        )
+        eval_kwargs = dict(mode=NO_OPTIMIZATION_MODE, on_unused_input="ignore")
+        np.testing.assert_allclose(
+            opt_out.eval({x: x_test, y: y_test}, **eval_kwargs),
+            out.eval({x: x_test, y: y_test}, **eval_kwargs),
+        )

-        x = vector("x")
-        y = scalar("y")
-        f = function([x, y], exp(x + y)[0], mode=mode_opt)
+    def test_local_subtensor_of_elemwise_multiple_clients(self):
+        x = pt.matrix("x", shape=(5, 3))
+        y = pt.matrix("y", shape=(5, 3))
+        out1 = add(x, y)
+        out2 = out1[0]

-        # Check stacktrace was copied over correctly after opt was applied
-        assert check_stack_trace(f, ops_to_check=Subtensor)
+        # Rewrite should fail when another node uses out1 directly (in this case it's an extra output)
+        fgraph = FunctionGraph([x, y], [out1, out2], clone=False)
+        assert local_subtensor_of_elemwise.transform(fgraph, out2.owner) is None

-        prog = f.maker.fgraph.toposort()
-        assert isinstance(prog[0].op, Subtensor)
-        # Composite{add,exp}
-        assert isinstance(prog[1].op.scalar_op, ps.Composite)
-        assert len(prog) == 2
-        f([1, 2, 3], 4)  # let debugmode test something
+        # Otherwise it should work
+        fgraph.remove_output(0)
+        assert local_subtensor_of_elemwise.transform(fgraph, out2.owner) is not None


 @pytest.mark.parametrize(
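
For context, a minimal sketch of the behavior these tests pin down, assuming pytensor is installed; it reproduces the first parametrize case above, where rewrite_graph lifts the Subtensor above the Elemwise so that exp is only computed for the indexed row:

    import pytensor.tensor as pt
    from pytensor.graph.basic import equal_computations
    from pytensor.graph.rewriting.utils import rewrite_graph

    x = pt.matrix("x", shape=(5, 3))
    out = pt.exp(x)[0]  # Subtensor applied after the Elemwise

    # After rewriting, the graph should match exp(x[0]): exp now runs on a
    # single row instead of the full 5x3 matrix.
    opt = rewrite_graph(out)
    assert equal_computations([opt], [pt.exp(x[0])])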