Skip to content

Commit 2413d99

Browse files
committed
Fix _delta on non rightmost axes
1 parent a8993d6 commit 2413d99

File tree

2 files changed

+8
-1
lines changed

2 files changed

+8
-1
lines changed

pytensor/tensor/einsum.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from pytensor.tensor.basic import (
1818
arange,
1919
as_tensor,
20+
expand_dims,
2021
get_vector_length,
2122
moveaxis,
2223
stack,
@@ -176,7 +177,8 @@ def _delta(shape: TensorVariable, axes: Sequence[int]) -> TensorVariable:
176177
iotas = [_iota(base_shape, i) for i in range(len(axes))]
177178
eyes = [eq(i1, i2) for i1, i2 in pairwise(iotas)]
178179
result = reduce(and_, eyes)
179-
return broadcast_to(result, shape)
180+
non_axes = [i for i in range(len(tuple(shape))) if i not in axes]
181+
return broadcast_to(expand_dims(result, non_axes), shape)
180182

181183

182184
def _general_dot(

tests/tensor/test_einsum.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,11 @@ def test_delta():
6767
[[1.0, 0.0], [0.0, 1.0]],
6868
)
6969

70+
np.testing.assert_allclose(
71+
_delta((2, 2, 2), (0, 1)).eval(mode=mode),
72+
[[[1, 1], [0, 0]], [[0, 0], [1, 1]]],
73+
)
74+
7075

7176
def test_general_dot():
7277
rng = np.random.default_rng(45)

0 commit comments

Comments
 (0)