1
1
import math
2
2
from collections import deque
3
- from typing import Iterable , Union
3
+ from typing import Iterable , Iterator , Tuple , Union
4
4
5
5
import pytest
6
6
from hypothesis import assume , given
@@ -33,8 +33,10 @@ def assert_array_ndindex(
33
33
x_indices : Iterable [Union [int , Shape ]],
34
34
out : Array ,
35
35
out_indices : Iterable [Union [int , Shape ]],
36
+ / ,
37
+ ** kw ,
36
38
):
37
- msg_suffix = f" [{ func_name } ()]\n { x = } \n { out = } "
39
+ msg_suffix = f" [{ func_name } ({ ph . fmt_kw ( kw ) } )]\n { x = } \n { out = } "
38
40
for x_idx , out_idx in zip (x_indices , out_indices ):
39
41
msg = f"out[{ out_idx } ]={ out [out_idx ]} , should be x[{ x_idx } ]={ x [x_idx ]} "
40
42
msg += msg_suffix
@@ -266,7 +268,15 @@ def test_reshape(x, data):
266
268
assert_array_ndindex ("reshape" , x , sh .ndindex (x .shape ), out , sh .ndindex (out .shape ))
267
269
268
270
269
- @pytest .mark .skip (reason = "faulty test logic" ) # TODO
271
+ def roll_ndindex (shape : Shape , shifts : Tuple [int ], axes : Tuple [int ]) -> Iterator [Shape ]:
272
+ assert len (shifts ) == len (axes ) # sanity check
273
+ all_shifts = [0 for _ in shape ]
274
+ for s , a in zip (shifts , axes ):
275
+ all_shifts [a ] = s
276
+ for idx in sh .ndindex (shape ):
277
+ yield tuple ((i + sh ) % si for i , sh , si in zip (idx , all_shifts , shape ))
278
+
279
+
270
280
@given (xps .arrays (dtype = xps .scalar_dtypes (), shape = shared_shapes ()), st .data ())
271
281
def test_roll (x , data ):
272
282
shift_strat = st .integers (- hh .MAX_ARRAY_SIZE , hh .MAX_ARRAY_SIZE )
@@ -287,6 +297,8 @@ def test_roll(x, data):
287
297
288
298
out = xp .roll (x , shift , ** kw )
289
299
300
+ kw = {"shift" : shift , ** kw } # for error messages
301
+
290
302
ph .assert_dtype ("roll" , x .dtype , out .dtype )
291
303
292
304
ph .assert_result_shape ("roll" , (x .shape ,), out .shape )
@@ -296,18 +308,12 @@ def test_roll(x, data):
296
308
indices = list (sh .ndindex (x .shape ))
297
309
shifted_indices = deque (indices )
298
310
shifted_indices .rotate (- shift )
299
- assert_array_ndindex ("roll" , x , indices , out , shifted_indices )
311
+ assert_array_ndindex ("roll" , x , indices , out , shifted_indices , ** kw )
300
312
else :
301
- _shift = (shift ,) if isinstance (shift , int ) else shift
313
+ shifts = (shift ,) if isinstance (shift , int ) else shift
302
314
axes = sh .normalise_axis (kw ["axis" ], x .ndim )
303
- all_indices = list (sh .ndindex (x .shape ))
304
- for s , a in zip (_shift , axes ):
305
- side = x .shape [a ]
306
- for i in range (side ):
307
- indices = [idx for idx in all_indices if idx [a ] == i ]
308
- shifted_indices = deque (indices )
309
- shifted_indices .rotate (- s )
310
- assert_array_ndindex ("roll" , x , indices , out , shifted_indices )
315
+ shifted_indices = roll_ndindex (x .shape , shifts , axes )
316
+ assert_array_ndindex ("roll" , x , sh .ndindex (x .shape ), out , shifted_indices , ** kw )
311
317
312
318
313
319
@given (
0 commit comments