Skip to content

Commit 180ef9d

Browse files
committed
Make repeated indexes work
1 parent 3e958ce commit 180ef9d

File tree

2 files changed

+37
-4
lines changed

2 files changed

+37
-4
lines changed

pytensor/tensor/einsum.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
from pytensor.compile.builders import OpFromGraph
99
from pytensor.tensor.basic import (
1010
arange,
11-
expand_dims,
1211
get_vector_length,
1312
stack,
1413
transpose,
@@ -36,9 +35,10 @@ def __init__(
3635

3736

3837
def _iota(shape: TensorVariable, axis: int) -> TensorVariable:
39-
axis = normalize_axis_index(axis, get_vector_length(shape))
38+
len_shape = get_vector_length(shape)
39+
axis = normalize_axis_index(axis, len_shape)
4040
values = arange(shape[axis])
41-
return broadcast_to(shape_padright(values, axis), shape)
41+
return broadcast_to(shape_padright(values, len_shape - axis - 1), shape)
4242

4343

4444
def _delta(shape, axes: Sequence[int]) -> TensorVariable:
@@ -47,7 +47,7 @@ def _delta(shape, axes: Sequence[int]) -> TensorVariable:
4747
iotas = [_iota(base_shape, i) for i in range(len(axes))]
4848
eyes = [eq(i1, i2) for i1, i2 in pairwise(iotas)]
4949
result = reduce(and_, eyes)
50-
return broadcast_to(expand_dims(result, tuple(axes)), shape)
50+
return broadcast_to(result, shape)
5151

5252

5353
def _removechars(s, chars):

tests/tensor/test_einsum.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,39 @@
44
import pytest
55

66
import pytensor.tensor as pt
7+
from pytensor import Mode
8+
from pytensor.tensor.einsum import _delta, _iota
9+
10+
11+
def test_iota():
12+
mode = Mode(linker="py", optimizer=None)
13+
np.testing.assert_allclose(
14+
_iota((4, 8), 0).eval(mode=mode),
15+
[
16+
[0, 0, 0, 0, 0, 0, 0, 0],
17+
[1, 1, 1, 1, 1, 1, 1, 1],
18+
[2, 2, 2, 2, 2, 2, 2, 2],
19+
[3, 3, 3, 3, 3, 3, 3, 3],
20+
],
21+
)
22+
23+
np.testing.assert_allclose(
24+
_iota((4, 8), 1).eval(mode=mode),
25+
[
26+
[0, 1, 2, 3, 4, 5, 6, 7],
27+
[0, 1, 2, 3, 4, 5, 6, 7],
28+
[0, 1, 2, 3, 4, 5, 6, 7],
29+
[0, 1, 2, 3, 4, 5, 6, 7],
30+
],
31+
)
32+
33+
34+
def test_delta():
35+
mode = Mode(linker="py", optimizer=None)
36+
np.testing.assert_allclose(
37+
_delta((2, 2), (0, 1)).eval(mode=mode),
38+
[[1.0, 0.0], [0.0, 1.0]],
39+
)
740

841

942
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)