Skip to content

Commit 32b14d2

Browse files
Add test for flip
1 parent 02566b6 commit 32b14d2

File tree

1 file changed

+25
-1
lines changed

1 file changed

+25
-1
lines changed

tests/tensor/test_pad.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import pytest
33

44
import pytensor
5-
from pytensor.tensor.pad import PadMode, pad
5+
from pytensor.tensor.pad import PadMode, flip, pad
66

77

88
floatX = pytensor.config.floatX
@@ -132,3 +132,27 @@ def test_symmetric_pad(size, pad_width, reflect_type):
132132
f = pytensor.function([], z, mode="FAST_COMPILE")
133133

134134
np.testing.assert_allclose(expected, f(), atol=ATOL, rtol=RTOL)
135+
136+
137+
@pytest.mark.parametrize(
138+
"size", [(3,), (3, 3), (3, 5, 5)], ids=["1d", "2d square", "3d square"]
139+
)
140+
def test_flip(size: tuple[int]):
141+
from itertools import combinations
142+
143+
x = np.random.normal(size=size).astype(floatX)
144+
x_pt = pytensor.tensor.tensor(shape=size, name="x")
145+
expected = np.flip(x, axis=None)
146+
z = flip(x_pt, axis=None)
147+
f = pytensor.function([x_pt], z, mode="FAST_COMPILE")
148+
np.testing.assert_allclose(expected, f(x), atol=ATOL, rtol=RTOL)
149+
150+
# Test all combinations of axes
151+
flip_options = [
152+
axes for i in range(1, x.ndim + 1) for axes in combinations(range(x.ndim), r=i)
153+
]
154+
for axes in flip_options:
155+
expected = np.flip(x, axis=list(axes))
156+
z = flip(x_pt, axis=list(axes))
157+
f = pytensor.function([x_pt], z, mode="FAST_COMPILE")
158+
np.testing.assert_allclose(expected, f(x), atol=ATOL, rtol=RTOL)

0 commit comments

Comments
 (0)