@@ -5,13 +5,14 @@
 import pytest

 import pytensor
-import pytensor.tensor as pt
 from pytensor import Mode, config, function
 from pytensor.graph import FunctionGraph
 from pytensor.graph.op import HasInnerGraph
+from pytensor.tensor.basic import moveaxis
 from pytensor.tensor.blockwise import Blockwise
 from pytensor.tensor.einsum import _delta, _general_dot, _iota, einsum
 from pytensor.tensor.shape import Reshape
+from pytensor.tensor.type import tensor


 # Fail for unexpected warnings in this file
@@ -80,8 +81,8 @@ def test_general_dot():

     # X has two batch dims
     # Y has one batch dim
-    x = pt.tensor("x", shape=(5, 4, 2, 11, 13, 3))
-    y = pt.tensor("y", shape=(4, 13, 5, 7, 11))
+    x = tensor("x", shape=(5, 4, 2, 11, 13, 3))
+    y = tensor("y", shape=(4, 13, 5, 7, 11))
     out = _general_dot((x, y), tensordot_axes, [(0, 1), (0,)])

     fn = pytensor.function([x, y], out)
@@ -135,10 +136,10 @@ def test_einsum_signatures(static_shape_known, signature):
     static_shapes = [[None] * len(shape) for shape in shapes]

     operands = [
-        pt.tensor(name, shape=static_shape)
+        tensor(name, shape=static_shape)
         for name, static_shape in zip(ascii_lowercase, static_shapes, strict=False)
     ]
-    out = pt.einsum(signature, *operands)
+    out = einsum(signature, *operands)
     assert out.owner.op.optimized == static_shape_known or len(operands) <= 2

     rng = np.random.default_rng(37)
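Aside, on the `optimized` assertion just above: a minimal sketch of the behavior it encodes, assuming the same `tensor`/`einsum` imports this file now uses (shapes and variable names here are illustrative, not from the diff). With three or more operands, the einsum op pre-computes its contraction path only when every static shape is known; the `or len(operands) <= 2` clause allows the flag not to track static shapes in the trivial one- or two-operand case.

from pytensor.tensor.einsum import einsum
from pytensor.tensor.type import tensor

# All static shapes known: the contraction path can be optimized eagerly.
x = tensor("x", shape=(3, 4))
y = tensor("y", shape=(4, 5))
z = tensor("z", shape=(5, 6))
assert einsum("ij,jk,kl->il", x, y, z).owner.op.optimized

# One unknown dimension: path optimization is deferred to runtime.
x_dyn = tensor("x_dyn", shape=(None, 4))
assert not einsum("ij,jk,kl->il", x_dyn, y, z).owner.op.optimized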
@@ -160,8 +161,8 @@ def test_batch_dim():
         "x": (7, 3, 5),
         "y": (5, 2),
     }
-    x, y = (pt.tensor(name, shape=shape) for name, shape in shapes.items())
-    out = pt.einsum("mij,jk->mik", x, y)
+    x, y = (tensor(name, shape=shape) for name, shape in shapes.items())
+    out = einsum("mij,jk->mik", x, y)

     assert out.type.shape == (7, 3, 2)

@@ -195,32 +196,32 @@ def test_einsum_conv():

 def test_ellipsis():
     rng = np.random.default_rng(159)
-    x = pt.tensor("x", shape=(3, 5, 7, 11))
-    y = pt.tensor("y", shape=(3, 5, 11, 13))
+    x = tensor("x", shape=(3, 5, 7, 11))
+    y = tensor("y", shape=(3, 5, 11, 13))
     x_test = rng.normal(size=x.type.shape).astype(floatX)
     y_test = rng.normal(size=y.type.shape).astype(floatX)
     expected_out = np.matmul(x_test, y_test)

     with pytest.raises(ValueError):
-        pt.einsum("mp,pn->mn", x, y)
+        einsum("mp,pn->mn", x, y)

-    out = pt.einsum("...mp,...pn->...mn", x, y)
+    out = einsum("...mp,...pn->...mn", x, y)
     np.testing.assert_allclose(
         out.eval({x: x_test, y: y_test}), expected_out, atol=ATOL, rtol=RTOL
     )

     # Put batch axes in the middle
-    new_x = pt.moveaxis(x, -2, 0)
-    new_y = pt.moveaxis(y, -2, 0)
-    out = pt.einsum("m...p,p...n->m...n", new_x, new_y)
+    new_x = moveaxis(x, -2, 0)
+    new_y = moveaxis(y, -2, 0)
+    out = einsum("m...p,p...n->m...n", new_x, new_y)
     np.testing.assert_allclose(
         out.eval({x: x_test, y: y_test}),
         expected_out.transpose(-2, 0, 1, -1),
         atol=ATOL,
         rtol=RTOL,
     )

-    out = pt.einsum("m...p,p...n->mn", new_x, new_y)
+    out = einsum("m...p,p...n->mn", new_x, new_y)
     np.testing.assert_allclose(
         out.eval({x: x_test, y: y_test}), expected_out.sum((0, 1)), atol=ATOL, rtol=RTOL
     )
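Aside: the `transpose(-2, 0, 1, -1)` expectation above follows from where the ellipsis axes land in `"m...p,p...n->m...n"`. A NumPy-only sketch of the same contraction, using the test's shapes (this mirrors, not replaces, the PyTensor code):

import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=(3, 5, 7, 11))
y = rng.normal(size=(3, 5, 11, 13))

# Move the second-to-last axis to the front, as the test does.
new_x = np.moveaxis(x, -2, 0)  # (7, 3, 5, 11)
new_y = np.moveaxis(y, -2, 0)  # (11, 3, 5, 13)

# "m...p,p...n->m...n": the batch axes (3, 5) sit in the middle of the output.
out = np.einsum("m...p,p...n->m...n", new_x, new_y)  # (7, 3, 5, 13)
np.testing.assert_allclose(out, np.matmul(x, y).transpose(-2, 0, 1, -1))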
@@ -236,9 +237,9 @@ def test_broadcastable_dims():
     # can lead to suboptimal paths. We check we issue a warning for the following example:
     # https://github.com/dgasmith/opt_einsum/issues/220
     rng = np.random.default_rng(222)
-    a = pt.tensor("a", shape=(32, 32, 32))
-    b = pt.tensor("b", shape=(1000, 32))
-    c = pt.tensor("c", shape=(1, 32))
+    a = tensor("a", shape=(32, 32, 32))
+    b = tensor("b", shape=(1000, 32))
+    c = tensor("c", shape=(1, 32))

     a_test = rng.normal(size=a.type.shape).astype(floatX)
     b_test = rng.normal(size=b.type.shape).astype(floatX)
@@ -248,11 +249,11 @@ def test_broadcastable_dims():
     with pytest.warns(
         UserWarning, match="This can result in a suboptimal contraction path"
     ):
-        suboptimal_out = pt.einsum("ijk,bj,bk->i", a, b, c)
+        suboptimal_out = einsum("ijk,bj,bk->i", a, b, c)
     assert not [set(p) for p in suboptimal_out.owner.op.path] == [{0, 2}, {0, 1}]

     # If we use a distinct letter we get the optimal path
-    optimal_out = pt.einsum("ijk,bj,ck->i", a, b, c)
+    optimal_out = einsum("ijk,bj,ck->i", a, b, c)
     assert [set(p) for p in optimal_out.owner.op.path] == [{0, 2}, {0, 1}]

     suboptimal_eval = suboptimal_out.eval({a: a_test, b: b_test, c: c_test})
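Aside, for context on the two `path` assertions: the linked opt_einsum issue is that a size-1 dimension sharing the label `b` with a size-1000 dimension can steer the optimizer to a worse contraction order. A sketch using NumPy's `einsum_path` to inspect the chosen paths, assuming NumPy accepts the same size-1 broadcast on the shared label that PyTensor does here:

import numpy as np

a = np.zeros((32, 32, 32))
b = np.zeros((1000, 32))
c = np.zeros((1, 32))  # size-1, broadcastable against b's 1000

# Shared label "b" for the 1000- and 1-sized dims (the suboptimal case above).
path_shared, info_shared = np.einsum_path("ijk,bj,bk->i", a, b, c, optimize="optimal")

# Distinct labels for those dims (the optimal case above).
path_distinct, info_distinct = np.einsum_path("ijk,bj,ck->i", a, b, c, optimize="optimal")

# The info strings include the contraction order and estimated FLOP counts.
print(path_shared)
print(path_distinct)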