Skip to content

Commit b7edb4f

Browse files
committed
MAINT: remove manual prod(sequence) -> math.prod
1 parent a3f3f37 commit b7edb4f

File tree

1 file changed

+4
-7
lines changed

1 file changed

+4
-7
lines changed

array_api_tests/hypothesis_helpers.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -217,9 +217,6 @@ def all_floating_dtypes() -> SearchStrategy[DataType]:
217217
# Size to use for 2-dim arrays
218218
SQRT_MAX_ARRAY_SIZE = int(math.sqrt(MAX_ARRAY_SIZE))
219219

220-
# np.prod and others have overflow and math.prod is Python 3.8+ only
221-
def prod(seq):
222-
return reduce(mul, seq, 1)
223220

224221
# hypotheses.strategies.tuples only generates tuples of a fixed size
225222
def tuples(elements, *, min_size=0, max_size=None, unique_by=None, unique=False):
@@ -233,7 +230,7 @@ def shapes(**kw):
233230
kw.setdefault('min_dims', 0)
234231
kw.setdefault('min_side', 0)
235232
return xps.array_shapes(**kw).filter(
236-
lambda shape: prod(i for i in shape if i) < MAX_ARRAY_SIZE
233+
lambda shape: math.prod(i for i in shape if i) < MAX_ARRAY_SIZE
237234
)
238235

239236

@@ -245,7 +242,7 @@ def matrix_shapes(draw, stack_shapes=shapes()):
245242
stack_shape = draw(stack_shapes)
246243
mat_shape = draw(xps.array_shapes(max_dims=2, min_dims=2))
247244
shape = stack_shape + mat_shape
248-
assume(prod(i for i in shape if i) < MAX_ARRAY_SIZE)
245+
assume(math.prod(i for i in shape if i) < MAX_ARRAY_SIZE)
249246
return shape
250247

251248
square_matrix_shapes = matrix_shapes().filter(lambda shape: shape[-1] == shape[-2])
@@ -290,7 +287,7 @@ def mutually_broadcastable_shapes(
290287
)
291288
.map(lambda BS: BS.input_shapes)
292289
.filter(lambda shapes: all(
293-
prod(i for i in s if i > 0) < MAX_ARRAY_SIZE for s in shapes
290+
math.prod(i for i in s if i > 0) < MAX_ARRAY_SIZE for s in shapes
294291
))
295292
)
296293

@@ -321,7 +318,7 @@ def positive_definite_matrices(draw, dtypes=floating_dtypes):
321318
base_shape = draw(shapes())
322319
n = draw(integers(0, 8)) # 8 is an arbitrary small but interesting-enough value
323320
shape = base_shape + (n, n)
324-
assume(prod(i for i in shape if i) < MAX_ARRAY_SIZE)
321+
assume(math.prod(i for i in shape if i) < MAX_ARRAY_SIZE)
325322
dtype = draw(dtypes)
326323
return broadcast_to(eye(n, dtype=dtype), shape)
327324

0 commit comments

Comments
 (0)