Skip to content

Commit 10b4683

Browse files
committed
size_gt_1 support for assert_s_axes_shape()
1 parent ef95ba1 commit 10b4683

File tree

1 file changed

+7
-2
lines changed

1 file changed

+7
-2
lines changed

array_api_tests/test_fft.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ def assert_n_axis_shape(
9999
n: Optional[int],
100100
axis: int,
101101
out: Array,
102-
size_gt_1=False,
102+
size_gt_1: bool = False,
103103
):
104104
_axis = len(x.shape) - 1 if axis == -1 else axis
105105
if n is None:
@@ -120,6 +120,7 @@ def assert_s_axes_shape(
120120
s: Optional[List[int]],
121121
axes: Optional[List[int]],
122122
out: Array,
123+
size_gt_1: bool = False,
123124
):
124125
_axes = sh.normalise_axis(axes, x.ndim)
125126
_s = x.shape if s is None else s
@@ -130,6 +131,10 @@ def assert_s_axes_shape(
130131
else:
131132
side = x.shape[i]
132133
expected.append(side)
134+
if size_gt_1:
135+
last_axis = _axes[-1]
136+
expected[last_axis] = 2 * (expected[last_axis] - 1)
137+
assume(expected[last_axis] > 0) # TODO: generate valid examples
133138
ph.assert_shape(func_name, out_shape=out.shape, expected=tuple(expected))
134139

135140

@@ -243,7 +248,7 @@ def test_irfftn(x, data):
243248
out = xp.fft.irfftn(x, **kwargs)
244249

245250
assert_fft_dtype("irfftn", in_dtype=x.dtype, out_dtype=out.dtype)
246-
# TODO: shape
251+
assert_s_axes_shape("rfftn", x=x, s=s, axes=axes, out=out, size_gt_1=True)
247252

248253

249254
@given(

0 commit comments

Comments
 (0)