1
1
import re
2
2
import itertools
3
3
from contextlib import contextmanager
4
- from functools import reduce
5
- from math import sqrt
4
+ from functools import reduce , wraps
5
+ import math
6
6
from operator import mul
7
- from typing import Any , List , NamedTuple , Optional , Sequence , Tuple , Union
7
+ import struct
8
+ from typing import Any , List , Mapping , NamedTuple , Optional , Sequence , Tuple , Union
8
9
9
10
from hypothesis import assume , reject
10
11
from hypothesis .strategies import (SearchStrategy , booleans , composite , floats ,
11
12
integers , just , lists , none , one_of ,
12
- sampled_from , shared )
13
+ sampled_from , shared , builds )
13
14
14
15
from . import _array_module as xp , api_version
15
16
from . import dtype_helpers as dh
20
21
from ._array_module import broadcast_to , eye , float32 , float64 , full
21
22
from .stubs import category_to_funcs
22
23
from .pytest_helpers import nargs
23
- from .typing import Array , DataType , Shape
24
+ from .typing import Array , DataType , Scalar , Shape
25
+
26
+
27
+ def _float32ify (n : Union [int , float ]) -> float :
28
+ n = float (n )
29
+ return struct .unpack ("!f" , struct .pack ("!f" , n ))[0 ]
30
+
31
+
32
+ @wraps (xps .from_dtype )
33
+ def from_dtype (dtype , ** kwargs ) -> SearchStrategy [Scalar ]:
34
+ """xps.from_dtype() without the crazy large numbers."""
35
+ if dtype == xp .bool :
36
+ return xps .from_dtype (dtype , ** kwargs )
37
+
38
+ if dtype in dh .complex_dtypes :
39
+ component_dtype = dh .dtype_components [dtype ]
40
+ else :
41
+ component_dtype = dtype
42
+
43
+ min_ , max_ = dh .dtype_ranges [component_dtype ]
44
+
45
+ if "min_value" not in kwargs .keys () and min_ != 0 :
46
+ assert min_ < 0 # sanity check
47
+ min_value = - 1 * math .floor (math .sqrt (abs (min_ )))
48
+ if component_dtype == xp .float32 :
49
+ min_value = _float32ify (min_value )
50
+ kwargs ["min_value" ] = min_value
51
+ if "max_value" not in kwargs .keys ():
52
+ assert max_ > 0 # sanity check
53
+ max_value = math .floor (math .sqrt (max_ ))
54
+ if component_dtype == xp .float32 :
55
+ max_value = _float32ify (max_value )
56
+ kwargs ["max_value" ] = max_value
57
+
58
+ if dtype in dh .complex_dtypes :
59
+ component_strat = xps .from_dtype (dh .dtype_components [dtype ], ** kwargs )
60
+ return builds (complex , component_strat , component_strat )
61
+ else :
62
+ return xps .from_dtype (dtype , ** kwargs )
63
+
64
+
65
+ @wraps (xps .arrays )
66
+ def arrays (dtype , * args , elements = None , ** kwargs ) -> SearchStrategy [Array ]:
67
+ """xps.arrays() without the crazy large numbers."""
68
+ if isinstance (dtype , SearchStrategy ):
69
+ return dtype .flatmap (lambda d : arrays (d , * args , elements = elements , ** kwargs ))
70
+
71
+ if elements is None :
72
+ elements = from_dtype (dtype )
73
+ elif isinstance (elements , Mapping ):
74
+ elements = from_dtype (dtype , ** elements )
75
+
76
+ return xps .arrays (dtype , * args , elements = elements , ** kwargs )
77
+
24
78
25
79
_dtype_categories = [(xp .bool ,), dh .uint_dtypes , dh .int_dtypes , dh .real_float_dtypes , dh .complex_dtypes ]
26
80
_sorted_dtypes = [d for category in _dtype_categories for d in category ]
@@ -145,7 +199,7 @@ def all_floating_dtypes() -> SearchStrategy[DataType]:
145
199
# Limit the total size of an array shape
146
200
MAX_ARRAY_SIZE = 10000
147
201
# Size to use for 2-dim arrays
148
- SQRT_MAX_ARRAY_SIZE = int (sqrt (MAX_ARRAY_SIZE ))
202
+ SQRT_MAX_ARRAY_SIZE = int (math . sqrt (MAX_ARRAY_SIZE ))
149
203
150
204
# np.prod and others have overflow and math.prod is Python 3.8+ only
151
205
def prod (seq ):
@@ -181,7 +235,7 @@ def matrix_shapes(draw, stack_shapes=shapes()):
181
235
182
236
@composite
183
237
def finite_matrices (draw , shape = matrix_shapes ()):
184
- return draw (xps . arrays (dtype = xps .floating_dtypes (),
238
+ return draw (arrays (dtype = xps .floating_dtypes (),
185
239
shape = shape ,
186
240
elements = dict (allow_nan = False ,
187
241
allow_infinity = False )))
@@ -190,7 +244,7 @@ def finite_matrices(draw, shape=matrix_shapes()):
190
244
# Should we set a max_value here?
191
245
_rtol_float_kw = dict (allow_nan = False , allow_infinity = False , min_value = 0 )
192
246
rtols = one_of (floats (** _rtol_float_kw ),
193
- xps . arrays (dtype = xps .floating_dtypes (),
247
+ arrays (dtype = xps .floating_dtypes (),
194
248
shape = rtol_shared_matrix_shapes .map (lambda shape : shape [:- 2 ]),
195
249
elements = _rtol_float_kw ))
196
250
@@ -233,7 +287,7 @@ def symmetric_matrices(draw, dtypes=xps.floating_dtypes(), finite=True):
233
287
if not isinstance (finite , bool ):
234
288
finite = draw (finite )
235
289
elements = {'allow_nan' : False , 'allow_infinity' : False } if finite else None
236
- a = draw (xps . arrays (dtype = dtype , shape = shape , elements = elements ))
290
+ a = draw (arrays (dtype = dtype , shape = shape , elements = elements ))
237
291
upper = xp .triu (a )
238
292
lower = xp .triu (a , k = 1 ).mT
239
293
return upper + lower
@@ -256,7 +310,7 @@ def invertible_matrices(draw, dtypes=xps.floating_dtypes(), stack_shapes=shapes(
256
310
n = draw (integers (0 , SQRT_MAX_ARRAY_SIZE ),)
257
311
stack_shape = draw (stack_shapes )
258
312
shape = stack_shape + (n , n )
259
- d = draw (xps . arrays (dtypes , shape = n * prod (stack_shape ),
313
+ d = draw (arrays (dtypes , shape = n * prod (stack_shape ),
260
314
elements = dict (allow_nan = False , allow_infinity = False )))
261
315
# Functions that require invertible matrices may do anything when it is
262
316
# singular, including raising an exception, so we make sure the diagonals
@@ -282,7 +336,7 @@ def two_broadcastable_shapes(draw):
282
336
sizes = integers (0 , MAX_ARRAY_SIZE )
283
337
sqrt_sizes = integers (0 , SQRT_MAX_ARRAY_SIZE )
284
338
285
- numeric_arrays = xps . arrays (
339
+ numeric_arrays = arrays (
286
340
dtype = shared (xps .floating_dtypes (), key = 'dtypes' ),
287
341
shape = shared (xps .array_shapes (), key = 'shapes' ),
288
342
)
@@ -407,11 +461,11 @@ def two_mutual_arrays(
407
461
assert len (dtypes ) > 0 # sanity check
408
462
mutual_dtypes = shared (mutually_promotable_dtypes (dtypes = dtypes ))
409
463
mutual_shapes = shared (two_shapes )
410
- arrays1 = xps . arrays (
464
+ arrays1 = arrays (
411
465
dtype = mutual_dtypes .map (lambda pair : pair [0 ]),
412
466
shape = mutual_shapes .map (lambda pair : pair [0 ]),
413
467
)
414
- arrays2 = xps . arrays (
468
+ arrays2 = arrays (
415
469
dtype = mutual_dtypes .map (lambda pair : pair [1 ]),
416
470
shape = mutual_shapes .map (lambda pair : pair [1 ]),
417
471
)
0 commit comments