Skip to content

Make boolean_indexing a separate flag from data_dependent_shapes #39

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions array_api_strict/_array_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,8 +436,8 @@ def _validate_index(self, key):
f"{len(key)=}, but masking is only specified in the "
"Array API when the array is the sole index."
)
if not get_array_api_strict_flags()['data_dependent_shapes']:
raise RuntimeError("Boolean array indexing (masking) requires data-dependent shapes, but the data_dependent_shapes flag has been disabled for array-api-strict")
if not get_array_api_strict_flags()['boolean_indexing']:
raise RuntimeError("Boolean array indexing (masking) requires data-dependent shapes, but the boolean_indexing flag has been disabled for array-api-strict")

elif i.dtype in _integer_dtypes and i.ndim != 0:
raise IndexError(
Expand Down
42 changes: 34 additions & 8 deletions array_api_strict/_flags.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@

API_VERSION = default_version = "2022.12"

BOOLEAN_INDEXING = True

DATA_DEPENDENT_SHAPES = True

all_extensions = (
Expand All @@ -46,6 +48,7 @@
def set_array_api_strict_flags(
*,
api_version=None,
boolean_indexing=None,
data_dependent_shapes=None,
enabled_extensions=None,
):
Expand All @@ -67,6 +70,12 @@ def set_array_api_strict_flags(
Note that 2021.12 is supported, but currently gives the same thing as
2022.12 (except that the fft extension will be disabled).


- `boolean_indexing`: Whether indexing by a boolean array is supported.
Note that although boolean array indexing does result in data-dependent
shapes, this flag is independent of the `data_dependent_shapes` flag
(see below).

- `data_dependent_shapes`: Whether data-dependent shapes are enabled in
array-api-strict.

Expand All @@ -79,10 +88,12 @@ def set_array_api_strict_flags(

- `unique_all`, `unique_counts`, `unique_inverse`, and `unique_values`.
- `nonzero`
- Boolean array indexing
- `repeat` when the `repeats` argument is an array (requires 2023.12
version of the standard)

Note that while boolean indexing is also data-dependent, it is
controlled by a separate `boolean_indexing` flag (see above).

See
https://data-apis.org/array-api/latest/design_topics/data_dependent_output_shapes.html
for more details.
Expand All @@ -102,8 +113,8 @@ def set_array_api_strict_flags(
>>> # Set the standard version to 2021.12
>>> set_array_api_strict_flags(api_version="2021.12")

>>> # Disable data-dependent shapes
>>> set_array_api_strict_flags(data_dependent_shapes=False)
>>> # Disable data-dependent shapes and boolean indexing
>>> set_array_api_strict_flags(data_dependent_shapes=False, boolean_indexing=False)

>>> # Enable only the linalg extension (disable the fft extension)
>>> set_array_api_strict_flags(enabled_extensions=["linalg"])
Expand All @@ -116,7 +127,7 @@ def set_array_api_strict_flags(
ArrayAPIStrictFlags: A context manager to temporarily set the flags.

"""
global API_VERSION, DATA_DEPENDENT_SHAPES, ENABLED_EXTENSIONS
global API_VERSION, BOOLEAN_INDEXING, DATA_DEPENDENT_SHAPES, ENABLED_EXTENSIONS

if api_version is not None:
if api_version not in supported_versions:
Expand All @@ -126,6 +137,9 @@ def set_array_api_strict_flags(
API_VERSION = api_version
array_api_strict.__array_api_version__ = API_VERSION

if boolean_indexing is not None:
BOOLEAN_INDEXING = boolean_indexing

if data_dependent_shapes is not None:
DATA_DEPENDENT_SHAPES = data_dependent_shapes

Expand Down Expand Up @@ -169,7 +183,11 @@ def get_array_api_strict_flags():
>>> from array_api_strict import get_array_api_strict_flags
>>> flags = get_array_api_strict_flags()
>>> flags
{'api_version': '2022.12', 'data_dependent_shapes': True, 'enabled_extensions': ('linalg', 'fft')}
{'api_version': '2022.12',
'boolean_indexing': True,
'data_dependent_shapes': True,
'enabled_extensions': ('linalg', 'fft')
}

See Also
--------
Expand All @@ -181,6 +199,7 @@ def get_array_api_strict_flags():
"""
return {
"api_version": API_VERSION,
"boolean_indexing": BOOLEAN_INDEXING,
"data_dependent_shapes": DATA_DEPENDENT_SHAPES,
"enabled_extensions": ENABLED_EXTENSIONS,
}
Expand Down Expand Up @@ -215,9 +234,10 @@ def reset_array_api_strict_flags():
ArrayAPIStrictFlags: A context manager to temporarily set the flags.

"""
global API_VERSION, DATA_DEPENDENT_SHAPES, ENABLED_EXTENSIONS
global API_VERSION, BOOLEAN_INDEXING, DATA_DEPENDENT_SHAPES, ENABLED_EXTENSIONS
API_VERSION = default_version
array_api_strict.__array_api_version__ = API_VERSION
BOOLEAN_INDEXING = True
DATA_DEPENDENT_SHAPES = True
ENABLED_EXTENSIONS = default_extensions

Expand All @@ -242,10 +262,11 @@ class ArrayAPIStrictFlags:
reset_array_api_strict_flags: Reset the flags to their default values.

"""
def __init__(self, *, api_version=None, data_dependent_shapes=None,
enabled_extensions=None):
def __init__(self, *, api_version=None, boolean_indexing=None,
data_dependent_shapes=None, enabled_extensions=None):
self.kwargs = {
"api_version": api_version,
"boolean_indexing": boolean_indexing,
"data_dependent_shapes": data_dependent_shapes,
"enabled_extensions": enabled_extensions,
}
Expand All @@ -265,6 +286,11 @@ def set_flags_from_environment():
api_version=os.environ["ARRAY_API_STRICT_API_VERSION"]
)

if "ARRAY_API_STRICT_BOOLEAN_INDEXING" in os.environ:
set_array_api_strict_flags(
boolean_indexing=os.environ["ARRAY_API_STRICT_BOOLEAN_INDEXING"].lower() == "true"
)

if "ARRAY_API_STRICT_DATA_DEPENDENT_SHAPES" in os.environ:
set_array_api_strict_flags(
data_dependent_shapes=os.environ["ARRAY_API_STRICT_DATA_DEPENDENT_SHAPES"].lower() == "true"
Expand Down
17 changes: 17 additions & 0 deletions array_api_strict/tests/test_flags.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ def test_flags():
flags = get_array_api_strict_flags()
assert flags == {
'api_version': '2022.12',
'boolean_indexing': True,
'data_dependent_shapes': True,
'enabled_extensions': ('linalg', 'fft'),
}
Expand All @@ -22,13 +23,15 @@ def test_flags():
flags = get_array_api_strict_flags()
assert flags == {
'api_version': '2022.12',
'boolean_indexing': True,
'data_dependent_shapes': False,
'enabled_extensions': ('linalg', 'fft'),
}
set_array_api_strict_flags(enabled_extensions=('fft',))
flags = get_array_api_strict_flags()
assert flags == {
'api_version': '2022.12',
'boolean_indexing': True,
'data_dependent_shapes': False,
'enabled_extensions': ('fft',),
}
Expand All @@ -41,6 +44,7 @@ def test_flags():
flags = get_array_api_strict_flags()
assert flags == {
'api_version': '2021.12',
'boolean_indexing': True,
'data_dependent_shapes': False,
'enabled_extensions': ('linalg',),
}
Expand All @@ -58,12 +62,14 @@ def test_flags():
with pytest.warns(UserWarning):
set_array_api_strict_flags(
api_version='2021.12',
boolean_indexing=False,
data_dependent_shapes=False,
enabled_extensions=())
reset_array_api_strict_flags()
flags = get_array_api_strict_flags()
assert flags == {
'api_version': '2022.12',
'boolean_indexing': True,
'data_dependent_shapes': True,
'enabled_extensions': ('linalg', 'fft'),
}
Expand Down Expand Up @@ -96,6 +102,17 @@ def test_data_dependent_shapes():
pytest.raises(RuntimeError, lambda: unique_inverse(a))
pytest.raises(RuntimeError, lambda: unique_values(a))
pytest.raises(RuntimeError, lambda: nonzero(a))
a[mask] # No error (boolean indexing is a separate flag)

def test_boolean_indexing():
a = asarray([0, 0, 1, 2, 2])
mask = asarray([True, False, True, False, True])

# Should not error
a[mask]

set_array_api_strict_flags(boolean_indexing=False)

pytest.raises(RuntimeError, lambda: a[mask])

linalg_examples = {
Expand Down
4 changes: 4 additions & 0 deletions docs/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@ used by array-api-strict initially. They will not change the defaults used by

A string representing the version number.

.. envvar:: ARRAY_API_STRICT_BOOLEAN_INDEXING

"True" or "False" to enable or disable boolean indexing.

.. envvar:: ARRAY_API_STRICT_DATA_DEPENDENT_SHAPES

"True" or "False" to enable or disable data dependent shapes.
Expand Down