1
1
import re
2
2
import itertools
3
3
from contextlib import contextmanager
4
- from functools import reduce
5
- from math import sqrt
4
+ from functools import reduce , wraps
5
+ import math
6
6
from operator import mul
7
- from typing import Any , List , NamedTuple , Optional , Sequence , Tuple , Union
7
+ import struct
8
+ from typing import Any , List , Mapping , NamedTuple , Optional , Sequence , Tuple , Union
8
9
9
10
from hypothesis import assume , reject
10
11
from hypothesis .strategies import (SearchStrategy , booleans , composite , floats ,
11
12
integers , just , lists , none , one_of ,
12
- sampled_from , shared )
13
+ sampled_from , shared , builds )
13
14
14
15
from . import _array_module as xp , api_version
15
16
from . import dtype_helpers as dh
20
21
from ._array_module import broadcast_to , eye , float32 , float64 , full
21
22
from .stubs import category_to_funcs
22
23
from .pytest_helpers import nargs
23
- from .typing import Array , DataType , Shape
24
-
25
- # Set this to True to not fail tests just because a dtype isn't implemented.
26
- # If no compatible dtype is implemented for a given test, the test will fail
27
- # with a hypothesis health check error. Note that this functionality will not
28
- # work for floating point dtypes as those are assumed to be defined in other
29
- # places in the tests.
30
- FILTER_UNDEFINED_DTYPES = True
31
- # TODO: currently we assume this to be true - we probably can remove this completely
32
- assert FILTER_UNDEFINED_DTYPES
33
-
34
- integer_dtypes = xps .integer_dtypes () | xps .unsigned_integer_dtypes ()
35
- floating_dtypes = xps .floating_dtypes ()
36
- numeric_dtypes = xps .numeric_dtypes ()
37
- integer_or_boolean_dtypes = xps .boolean_dtypes () | integer_dtypes
38
- boolean_dtypes = xps .boolean_dtypes ()
39
- dtypes = xps .scalar_dtypes ()
40
-
41
- shared_dtypes = shared (dtypes , key = "dtype" )
42
- shared_floating_dtypes = shared (floating_dtypes , key = "dtype" )
24
+ from .typing import Array , DataType , Scalar , Shape
25
+
26
+
27
+ def _float32ify (n : Union [int , float ]) -> float :
28
+ n = float (n )
29
+ return struct .unpack ("!f" , struct .pack ("!f" , n ))[0 ]
30
+
31
+
32
+ @wraps (xps .from_dtype )
33
+ def from_dtype (dtype , ** kwargs ) -> SearchStrategy [Scalar ]:
34
+ """xps.from_dtype() without the crazy large numbers."""
35
+ if dtype == xp .bool :
36
+ return xps .from_dtype (dtype , ** kwargs )
37
+
38
+ if dtype in dh .complex_dtypes :
39
+ component_dtype = dh .dtype_components [dtype ]
40
+ else :
41
+ component_dtype = dtype
42
+
43
+ min_ , max_ = dh .dtype_ranges [component_dtype ]
44
+
45
+ if "min_value" not in kwargs .keys () and min_ != 0 :
46
+ assert min_ < 0 # sanity check
47
+ min_value = - 1 * math .floor (math .sqrt (abs (min_ )))
48
+ if component_dtype == xp .float32 :
49
+ min_value = _float32ify (min_value )
50
+ kwargs ["min_value" ] = min_value
51
+ if "max_value" not in kwargs .keys ():
52
+ assert max_ > 0 # sanity check
53
+ max_value = math .floor (math .sqrt (max_ ))
54
+ if component_dtype == xp .float32 :
55
+ max_value = _float32ify (max_value )
56
+ kwargs ["max_value" ] = max_value
57
+
58
+ if dtype in dh .complex_dtypes :
59
+ component_strat = xps .from_dtype (dh .dtype_components [dtype ], ** kwargs )
60
+ return builds (complex , component_strat , component_strat )
61
+ else :
62
+ return xps .from_dtype (dtype , ** kwargs )
63
+
64
+
65
+ @wraps (xps .arrays )
66
+ def arrays (dtype , * args , elements = None , ** kwargs ) -> SearchStrategy [Array ]:
67
+ """xps.arrays() without the crazy large numbers."""
68
+ if isinstance (dtype , SearchStrategy ):
69
+ return dtype .flatmap (lambda d : arrays (d , * args , elements = elements , ** kwargs ))
70
+
71
+ if elements is None :
72
+ elements = from_dtype (dtype )
73
+ elif isinstance (elements , Mapping ):
74
+ elements = from_dtype (dtype , ** elements )
75
+
76
+ return xps .arrays (dtype , * args , elements = elements , ** kwargs )
77
+
43
78
44
79
_dtype_categories = [(xp .bool ,), dh .uint_dtypes , dh .int_dtypes , dh .real_float_dtypes , dh .complex_dtypes ]
45
80
_sorted_dtypes = [d for category in _dtype_categories for d in category ]
@@ -62,21 +97,19 @@ def _dtypes_sorter(dtype_pair: Tuple[DataType, DataType]):
62
97
return key
63
98
64
99
_promotable_dtypes = list (dh .promotion_table .keys ())
65
- if FILTER_UNDEFINED_DTYPES :
66
- _promotable_dtypes = [
67
- (d1 , d2 ) for d1 , d2 in _promotable_dtypes
68
- if not isinstance (d1 , _UndefinedStub ) or not isinstance (d2 , _UndefinedStub )
69
- ]
100
+ _promotable_dtypes = [
101
+ (d1 , d2 ) for d1 , d2 in _promotable_dtypes
102
+ if not isinstance (d1 , _UndefinedStub ) or not isinstance (d2 , _UndefinedStub )
103
+ ]
70
104
promotable_dtypes : List [Tuple [DataType , DataType ]] = sorted (_promotable_dtypes , key = _dtypes_sorter )
71
105
72
106
def mutually_promotable_dtypes (
73
107
max_size : Optional [int ] = 2 ,
74
108
* ,
75
109
dtypes : Sequence [DataType ] = dh .all_dtypes ,
76
110
) -> SearchStrategy [Tuple [DataType , ...]]:
77
- if FILTER_UNDEFINED_DTYPES :
78
- dtypes = [d for d in dtypes if not isinstance (d , _UndefinedStub )]
79
- assert len (dtypes ) > 0 , "all dtypes undefined" # sanity check
111
+ dtypes = [d for d in dtypes if not isinstance (d , _UndefinedStub )]
112
+ assert len (dtypes ) > 0 , "all dtypes undefined" # sanity check
80
113
if max_size == 2 :
81
114
return sampled_from (
82
115
[(i , j ) for i , j in promotable_dtypes if i in dtypes and j in dtypes ]
@@ -166,7 +199,7 @@ def all_floating_dtypes() -> SearchStrategy[DataType]:
166
199
# Limit the total size of an array shape
167
200
MAX_ARRAY_SIZE = 10000
168
201
# Size to use for 2-dim arrays
169
- SQRT_MAX_ARRAY_SIZE = int (sqrt (MAX_ARRAY_SIZE ))
202
+ SQRT_MAX_ARRAY_SIZE = int (math . sqrt (MAX_ARRAY_SIZE ))
170
203
171
204
# np.prod and others have overflow and math.prod is Python 3.8+ only
172
205
def prod (seq ):
@@ -202,7 +235,7 @@ def matrix_shapes(draw, stack_shapes=shapes()):
202
235
203
236
@composite
204
237
def finite_matrices (draw , shape = matrix_shapes ()):
205
- return draw (xps . arrays (dtype = xps .floating_dtypes (),
238
+ return draw (arrays (dtype = xps .floating_dtypes (),
206
239
shape = shape ,
207
240
elements = dict (allow_nan = False ,
208
241
allow_infinity = False )))
@@ -211,7 +244,7 @@ def finite_matrices(draw, shape=matrix_shapes()):
211
244
# Should we set a max_value here?
212
245
_rtol_float_kw = dict (allow_nan = False , allow_infinity = False , min_value = 0 )
213
246
rtols = one_of (floats (** _rtol_float_kw ),
214
- xps . arrays (dtype = xps .floating_dtypes (),
247
+ arrays (dtype = xps .floating_dtypes (),
215
248
shape = rtol_shared_matrix_shapes .map (lambda shape : shape [:- 2 ]),
216
249
elements = _rtol_float_kw ))
217
250
@@ -254,7 +287,7 @@ def symmetric_matrices(draw, dtypes=xps.floating_dtypes(), finite=True):
254
287
if not isinstance (finite , bool ):
255
288
finite = draw (finite )
256
289
elements = {'allow_nan' : False , 'allow_infinity' : False } if finite else None
257
- a = draw (xps . arrays (dtype = dtype , shape = shape , elements = elements ))
290
+ a = draw (arrays (dtype = dtype , shape = shape , elements = elements ))
258
291
upper = xp .triu (a )
259
292
lower = xp .triu (a , k = 1 ).mT
260
293
return upper + lower
@@ -277,7 +310,7 @@ def invertible_matrices(draw, dtypes=xps.floating_dtypes(), stack_shapes=shapes(
277
310
n = draw (integers (0 , SQRT_MAX_ARRAY_SIZE ),)
278
311
stack_shape = draw (stack_shapes )
279
312
shape = stack_shape + (n , n )
280
- d = draw (xps . arrays (dtypes , shape = n * prod (stack_shape ),
313
+ d = draw (arrays (dtypes , shape = n * prod (stack_shape ),
281
314
elements = dict (allow_nan = False , allow_infinity = False )))
282
315
# Functions that require invertible matrices may do anything when it is
283
316
# singular, including raising an exception, so we make sure the diagonals
@@ -303,7 +336,7 @@ def two_broadcastable_shapes(draw):
303
336
sizes = integers (0 , MAX_ARRAY_SIZE )
304
337
sqrt_sizes = integers (0 , SQRT_MAX_ARRAY_SIZE )
305
338
306
- numeric_arrays = xps . arrays (
339
+ numeric_arrays = arrays (
307
340
dtype = shared (xps .floating_dtypes (), key = 'dtypes' ),
308
341
shape = shared (xps .array_shapes (), key = 'shapes' ),
309
342
)
@@ -348,7 +381,7 @@ def python_integer_indices(draw, sizes):
348
381
def integer_indices (draw , sizes ):
349
382
# Return either a Python integer or a 0-D array with some integer dtype
350
383
idx = draw (python_integer_indices (sizes ))
351
- dtype = draw (integer_dtypes )
384
+ dtype = draw (xps . integer_dtypes () | xps . unsigned_integer_dtypes () )
352
385
m , M = dh .dtype_ranges [dtype ]
353
386
if m <= idx <= M :
354
387
return draw (one_of (just (idx ),
@@ -424,16 +457,15 @@ def two_mutual_arrays(
424
457
) -> Tuple [SearchStrategy [Array ], SearchStrategy [Array ]]:
425
458
if not isinstance (dtypes , Sequence ):
426
459
raise TypeError (f"{ dtypes = } not a sequence" )
427
- if FILTER_UNDEFINED_DTYPES :
428
- dtypes = [d for d in dtypes if not isinstance (d , _UndefinedStub )]
429
- assert len (dtypes ) > 0 # sanity check
460
+ dtypes = [d for d in dtypes if not isinstance (d , _UndefinedStub )]
461
+ assert len (dtypes ) > 0 # sanity check
430
462
mutual_dtypes = shared (mutually_promotable_dtypes (dtypes = dtypes ))
431
463
mutual_shapes = shared (two_shapes )
432
- arrays1 = xps . arrays (
464
+ arrays1 = arrays (
433
465
dtype = mutual_dtypes .map (lambda pair : pair [0 ]),
434
466
shape = mutual_shapes .map (lambda pair : pair [0 ]),
435
467
)
436
- arrays2 = xps . arrays (
468
+ arrays2 = arrays (
437
469
dtype = mutual_dtypes .map (lambda pair : pair [1 ]),
438
470
shape = mutual_shapes .map (lambda pair : pair [1 ]),
439
471
)
0 commit comments