Skip to content

Commit 50dff4c

Browse files
34jnstarman
andcommitted
fix: array to Array
Co-authored-by: Nathaniel Starkman <[email protected]>
1 parent 8a7a43e commit 50dff4c

File tree

5 files changed

+58
-39
lines changed

5 files changed

+58
-39
lines changed

.pre-commit-config.yaml

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -48,11 +48,29 @@ repos:
4848
rev: 23.7.0
4949
hooks:
5050
- id: black
51+
5152
- repo: https://github.com/pre-commit/mirrors-mypy
52-
rev: v1.13.0
53+
rev: "v1.0.0"
5354
hooks:
5455
- id: mypy
55-
additional_dependencies: []
56-
args: [--ignore-missing-imports, .]
57-
pass_filenames: false
58-
exclude: "src/array_api_stubs/.*"
56+
additional_dependencies: [typing_extensions>=4.4.0]
57+
args:
58+
- --ignore-missing-imports
59+
- --config=pyproject.toml
60+
files: ".*(_draft.*)$"
61+
exclude: |
62+
(?x)^(
63+
.*creation_functions.py|
64+
.*data_type_functions.py|
65+
.*elementwise_functions.py|
66+
.*fft.py|
67+
.*indexing_functions.py|
68+
.*linalg.py|
69+
.*linear_algebra_functions.py|
70+
.*manipulation_functions.py|
71+
.*searching_functions.py|
72+
.*set_functions.py|
73+
.*sorting_functions.py|
74+
.*statistical_functions.py|
75+
.*utility_functions.py|
76+
)$

pyproject.toml

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -34,14 +34,10 @@ build-backend = "setuptools.build_meta"
3434
line-length = 88
3535

3636
[tool.mypy]
37-
exclude = [
38-
"docs/.*",
39-
"spec/.*",
40-
"venv/.*",
41-
".venv/.*",
42-
"src/array_api_stubs/_2021_12/.*",
43-
"src/array_api_stubs/_2022_12/.*",
44-
"src/array_api_stubs/_2023_12/.*",
45-
"src/_array_api_conf.py"
37+
python_version = "3.9"
38+
mypy_path = "$MYPY_CONFIG_FILE_DIR/src/array_api_stubs/_draft/"
39+
files = [
40+
"src/array_api_stubs/_draft/**/*.py"
4641
]
42+
follow_imports = "silent"
4743
disable_error_code = "empty-body,type-var"

src/_array_api_conf.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,9 @@
6666
]
6767
nitpick_ignore_regex = [
6868
("py:class", ".*array"),
69+
("py:class", ".*Array"),
6970
("py:class", ".*device"),
71+
("py:class", ".*Device"),
7072
("py:class", ".*dtype"),
7173
("py:class", ".*NestedSequence"),
7274
("py:class", ".*SupportsBufferProtocol"),
@@ -84,6 +86,7 @@
8486
"array": "array",
8587
"Device": "device",
8688
"Dtype": "dtype",
89+
"DType": "dtype",
8790
}
8891

8992
# Make autosummary show the signatures of functions in the tables using actual

src/array_api_stubs/_draft/_types.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@
3131
"Info",
3232
]
3333

34-
from dataclasses import dataclass
3534
from typing import (
3635
Any,
3736
List,
@@ -45,12 +44,13 @@
4544
Protocol,
4645
)
4746
from enum import Enum
47+
from .data_types import DType
4848

4949
array = TypeVar("array", bound="_array")
5050
device = TypeVar("device")
51-
dtype = TypeVar("dtype")
52-
Device = TypeVar("Device")
53-
Dtype = TypeVar("Dtype")
51+
dtype = TypeVar("dtype", bound=DType)
52+
device_ = TypeVar("device_") # only used in this file
53+
dtype_ = TypeVar("dtype_", bound=DType) # only used in this file
5454
SupportsDLPack = TypeVar("SupportsDLPack")
5555
SupportsBufferProtocol = TypeVar("SupportsBufferProtocol")
5656
PyCapsule = TypeVar("PyCapsule")
@@ -149,12 +149,12 @@ def dtypes(
149149
)
150150

151151

152-
class _array(Protocol[array, Dtype, Device, PyCapsule]): # type: ignore
152+
class _array(Protocol[array, dtype_, device_, PyCapsule]): # type: ignore
153153
def __init__(self: array) -> None:
154154
"""Initialize the attributes for the array object class."""
155155

156156
@property
157-
def dtype(self: array) -> Dtype:
157+
def dtype(self: array) -> dtype_:
158158
"""
159159
Data type of the array elements.
160160
@@ -165,7 +165,7 @@ def dtype(self: array) -> Dtype:
165165
"""
166166

167167
@property
168-
def device(self: array) -> Device:
168+
def device(self: array) -> device_:
169169
"""
170170
Hardware device the array data resides on.
171171
@@ -1344,7 +1344,7 @@ def __xor__(self: array, other: Union[int, bool, array], /) -> array:
13441344
"""
13451345

13461346
def to_device(
1347-
self: array, device: Device, /, *, stream: Optional[Union[int, Any]] = None
1347+
self: array, device: device_, /, *, stream: Optional[Union[int, Any]] = None
13481348
) -> array:
13491349
"""
13501350
Copy the array from the device on which it currently resides to the specified ``device``.
Lines changed: 19 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,24 @@
1-
__all__ = ["__eq__"]
1+
from __future__ import annotations
22

3+
__all__ = ["DType"]
34

4-
from ._types import dtype
55

6+
from typing import Protocol
67

7-
def __eq__(self: dtype, other: dtype, /) -> bool:
8-
"""
9-
Computes the truth value of ``self == other`` in order to test for data type object equality.
108

11-
Parameters
12-
----------
13-
self: dtype
14-
data type instance. May be any supported data type.
15-
other: dtype
16-
other data type instance. May be any supported data type.
17-
18-
Returns
19-
-------
20-
out: bool
21-
a boolean indicating whether the data type objects are equal.
22-
"""
9+
class DType(Protocol):
10+
def __eq__(self, other: DType, /) -> bool:
11+
"""
12+
Computes the truth value of ``self == other`` in order to test for data type object equality.
13+
Parameters
14+
----------
15+
self: dtype
16+
data type instance. May be any supported data type.
17+
other: dtype
18+
other data type instance. May be any supported data type.
19+
Returns
20+
-------
21+
out: bool
22+
a boolean indicating whether the data type objects are equal.
23+
"""
24+
...

0 commit comments

Comments
 (0)