Commit c4a3444

Use explicit imports in test_einsum

Parent: 884dee9
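The change is mechanical: every pt.-qualified helper used in the test module (pt.tensor, pt.einsum, pt.moveaxis) is replaced by a direct import from its defining module. A minimal sketch of the two styles (the names and shapes here are illustrative, not taken from the test file):

    # Old style: everything goes through the namespace alias
    import pytensor.tensor as pt

    x = pt.tensor("x", shape=(3, 5))
    y = pt.tensor("y", shape=(5, 2))
    out = pt.einsum("ij,jk->ik", x, y)

    # New style: explicit imports from the defining modules
    from pytensor.tensor.einsum import einsum
    from pytensor.tensor.type import tensor

    x = tensor("x", shape=(3, 5))
    y = tensor("y", shape=(5, 2))
    out = einsum("ij,jk->ik", x, y)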

1 file changed
tests/tensor/test_einsum.py (+21 lines, -20 lines)
@@ -5,13 +5,14 @@
 import pytest

 import pytensor
-import pytensor.tensor as pt
 from pytensor import Mode, config, function
 from pytensor.graph import FunctionGraph
 from pytensor.graph.op import HasInnerGraph
+from pytensor.tensor.basic import moveaxis
 from pytensor.tensor.blockwise import Blockwise
 from pytensor.tensor.einsum import _delta, _general_dot, _iota, einsum
 from pytensor.tensor.shape import Reshape
+from pytensor.tensor.type import tensor


 # Fail for unexpected warnings in this file
@@ -80,8 +81,8 @@ def test_general_dot():

     # X has two batch dims
     # Y has one batch dim
-    x = pt.tensor("x", shape=(5, 4, 2, 11, 13, 3))
-    y = pt.tensor("y", shape=(4, 13, 5, 7, 11))
+    x = tensor("x", shape=(5, 4, 2, 11, 13, 3))
+    y = tensor("y", shape=(4, 13, 5, 7, 11))
     out = _general_dot((x, y), tensordot_axes, [(0, 1), (0,)])

     fn = pytensor.function([x, y], out)
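The comments in this hunk mark which leading axes are batch dimensions. As a loose analogy only (not a claim about _general_dot's internals), NumPy's matmul broadcasts a core matrix product over leading batch axes in the same spirit:

    import numpy as np

    x = np.ones((5, 4, 2, 3))  # batch axes (5, 4), core matrix (2, 3)
    y = np.ones((5, 4, 3, 7))  # matching batch axes, core matrix (3, 7)

    # matmul contracts the trailing core axes and broadcasts the rest
    assert (x @ y).shape == (5, 4, 2, 7)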
@@ -135,10 +136,10 @@ def test_einsum_signatures(static_shape_known, signature):
     static_shapes = [[None] * len(shape) for shape in shapes]

     operands = [
-        pt.tensor(name, shape=static_shape)
+        tensor(name, shape=static_shape)
         for name, static_shape in zip(ascii_lowercase, static_shapes, strict=False)
     ]
-    out = pt.einsum(signature, *operands)
+    out = einsum(signature, *operands)
     assert out.owner.op.optimized == static_shape_known or len(operands) <= 2

     rng = np.random.default_rng(37)
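The assertion in this hunk encodes when the einsum Op pre-computes an optimized contraction path: only when every static shape is known, with an escape for two or fewer operands (presumably because there is then no path to choose). A minimal sketch of the two branches, using the same out.owner.op.optimized attribute the test reads:

    from pytensor.tensor.einsum import einsum
    from pytensor.tensor.type import tensor

    # All static shapes known: a contraction path is optimized ahead of time
    x, y, z = (tensor(n, shape=s) for n, s in [("x", (3, 4)), ("y", (4, 5)), ("z", (5, 6))])
    assert einsum("ij,jk,kl->il", x, y, z).owner.op.optimized

    # Any unknown (None) dim: no optimized path can be pre-computed
    x, y, z = (tensor(n, shape=(None, None)) for n in "xyz")
    assert not einsum("ij,jk,kl->il", x, y, z).owner.op.optimized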
@@ -160,8 +161,8 @@ def test_batch_dim():
         "x": (7, 3, 5),
         "y": (5, 2),
     }
-    x, y = (pt.tensor(name, shape=shape) for name, shape in shapes.items())
-    out = pt.einsum("mij,jk->mik", x, y)
+    x, y = (tensor(name, shape=shape) for name, shape in shapes.items())
+    out = einsum("mij,jk->mik", x, y)

     assert out.type.shape == (7, 3, 2)

@@ -195,32 +196,32 @@ def test_einsum_conv():

 def test_ellipsis():
     rng = np.random.default_rng(159)
-    x = pt.tensor("x", shape=(3, 5, 7, 11))
-    y = pt.tensor("y", shape=(3, 5, 11, 13))
+    x = tensor("x", shape=(3, 5, 7, 11))
+    y = tensor("y", shape=(3, 5, 11, 13))
     x_test = rng.normal(size=x.type.shape).astype(floatX)
     y_test = rng.normal(size=y.type.shape).astype(floatX)
     expected_out = np.matmul(x_test, y_test)

     with pytest.raises(ValueError):
-        pt.einsum("mp,pn->mn", x, y)
+        einsum("mp,pn->mn", x, y)

-    out = pt.einsum("...mp,...pn->...mn", x, y)
+    out = einsum("...mp,...pn->...mn", x, y)
     np.testing.assert_allclose(
         out.eval({x: x_test, y: y_test}), expected_out, atol=ATOL, rtol=RTOL
     )

     # Put batch axes in the middle
-    new_x = pt.moveaxis(x, -2, 0)
-    new_y = pt.moveaxis(y, -2, 0)
-    out = pt.einsum("m...p,p...n->m...n", new_x, new_y)
+    new_x = moveaxis(x, -2, 0)
+    new_y = moveaxis(y, -2, 0)
+    out = einsum("m...p,p...n->m...n", new_x, new_y)
     np.testing.assert_allclose(
         out.eval({x: x_test, y: y_test}),
         expected_out.transpose(-2, 0, 1, -1),
         atol=ATOL,
         rtol=RTOL,
     )

-    out = pt.einsum("m...p,p...n->mn", new_x, new_y)
+    out = einsum("m...p,p...n->mn", new_x, new_y)
     np.testing.assert_allclose(
         out.eval({x: x_test, y: y_test}), expected_out.sum((0, 1)), atol=ATOL, rtol=RTOL
     )
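The ellipsis semantics exercised in this hunk mirror NumPy's: "..." absorbs however many leading batch axes the operands share, which is why the test can compare against plain np.matmul. The same identity, checked standalone in NumPy:

    import numpy as np

    rng = np.random.default_rng(159)
    x = rng.normal(size=(3, 5, 7, 11))
    y = rng.normal(size=(3, 5, 11, 13))

    # "..." stands in for the shared (3, 5) batch axes
    out = np.einsum("...mp,...pn->...mn", x, y)
    np.testing.assert_allclose(out, np.matmul(x, y))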
@@ -236,9 +237,9 @@ def test_broadcastable_dims():
     # can lead to suboptimal paths. We check we issue a warning for the following example:
     # https://github.com/dgasmith/opt_einsum/issues/220
     rng = np.random.default_rng(222)
-    a = pt.tensor("a", shape=(32, 32, 32))
-    b = pt.tensor("b", shape=(1000, 32))
-    c = pt.tensor("c", shape=(1, 32))
+    a = tensor("a", shape=(32, 32, 32))
+    b = tensor("b", shape=(1000, 32))
+    c = tensor("c", shape=(1, 32))

     a_test = rng.normal(size=a.type.shape).astype(floatX)
     b_test = rng.normal(size=b.type.shape).astype(floatX)
@@ -248,11 +249,11 @@ def test_broadcastable_dims():
     with pytest.warns(
         UserWarning, match="This can result in a suboptimal contraction path"
     ):
-        suboptimal_out = pt.einsum("ijk,bj,bk->i", a, b, c)
+        suboptimal_out = einsum("ijk,bj,bk->i", a, b, c)
     assert not [set(p) for p in suboptimal_out.owner.op.path] == [{0, 2}, {0, 1}]

     # If we use a distinct letter we get the optimal path
-    optimal_out = pt.einsum("ijk,bj,ck->i", a, b, c)
+    optimal_out = einsum("ijk,bj,ck->i", a, b, c)
     assert [set(p) for p in optimal_out.owner.op.path] == [{0, 2}, {0, 1}]

     suboptimal_eval = suboptimal_out.eval({a: a_test, b: b_test, c: c_test})
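For intuition about why the letter choice matters, the contraction order for the well-behaved, distinct-letter spelling can be inspected with NumPy's np.einsum_path (NumPy's path optimizer is not opt_einsum's, so this is an analogy for inspecting paths, not a reproduction of the warning case):

    import numpy as np

    rng = np.random.default_rng(222)
    a = rng.normal(size=(32, 32, 32))
    b = rng.normal(size=(1000, 32))
    c = rng.normal(size=(1, 32))

    # Distinct letters b and c keep the broadcastable size-1 axis out of the
    # shared-index bookkeeping; einsum_path reports the pairwise order chosen
    path, report = np.einsum_path("ijk,bj,ck->i", a, b, c, optimize="optimal")
    print(path)    # the pairwise contraction steps
    print(report)  # a human-readable cost summary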
