|
2 | 2 | import pytest
|
3 | 3 |
|
4 | 4 | import pytensor
|
5 |
| -from pytensor.tensor.pad import PadMode, pad |
| 5 | +from pytensor.tensor.pad import PadMode, flip, pad |
6 | 6 |
|
7 | 7 |
|
8 | 8 | floatX = pytensor.config.floatX
|
@@ -132,3 +132,27 @@ def test_symmetric_pad(size, pad_width, reflect_type):
|
132 | 132 | f = pytensor.function([], z, mode="FAST_COMPILE")
|
133 | 133 |
|
134 | 134 | np.testing.assert_allclose(expected, f(), atol=ATOL, rtol=RTOL)
|
| 135 | + |
| 136 | + |
| 137 | +@pytest.mark.parametrize( |
| 138 | + "size", [(3,), (3, 3), (3, 5, 5)], ids=["1d", "2d square", "3d square"] |
| 139 | +) |
| 140 | +def test_flip(size: tuple[int]): |
| 141 | + from itertools import combinations |
| 142 | + |
| 143 | + x = np.random.normal(size=size).astype(floatX) |
| 144 | + x_pt = pytensor.tensor.tensor(shape=size, name="x") |
| 145 | + expected = np.flip(x, axis=None) |
| 146 | + z = flip(x_pt, axis=None) |
| 147 | + f = pytensor.function([x_pt], z, mode="FAST_COMPILE") |
| 148 | + np.testing.assert_allclose(expected, f(x), atol=ATOL, rtol=RTOL) |
| 149 | + |
| 150 | + # Test all combinations of axes |
| 151 | + flip_options = [ |
| 152 | + axes for i in range(1, x.ndim + 1) for axes in combinations(range(x.ndim), r=i) |
| 153 | + ] |
| 154 | + for axes in flip_options: |
| 155 | + expected = np.flip(x, axis=list(axes)) |
| 156 | + z = flip(x_pt, axis=list(axes)) |
| 157 | + f = pytensor.function([x_pt], z, mode="FAST_COMPILE") |
| 158 | + np.testing.assert_allclose(expected, f(x), atol=ATOL, rtol=RTOL) |
0 commit comments