Skip to content

Commit c325be0

Browse files
committed
Add some TODO comments
1 parent 0334ced commit c325be0

File tree

1 file changed

+3
-0
lines changed

1 file changed

+3
-0
lines changed

array_api_tests/test_linalg.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,7 @@ def test_cross(x1_x2_kw):
131131

132132
res = _array_module.linalg.cross(x1, x2, **kw)
133133

134+
# TODO: Replace result_type() with a helper function
134135
assert res.dtype == _array_module.result_type(x1, x2), "cross() did not return the correct dtype"
135136
assert res.shape == shape, "cross() did not return the correct shape"
136137

@@ -253,6 +254,7 @@ def test_inv(x):
253254
*two_mutual_arrays(numeric_dtype_objects)
254255
)
255256
def test_matmul(x1, x2):
257+
# TODO: Make this also test the @ operator
256258
if (x1.shape == () or x2.shape == ()
257259
or len(x1.shape) == len(x2.shape) == 1 and x1.shape != x2.shape
258260
or len(x1.shape) == 1 and len(x2.shape) >= 2 and x1.shape[0] != x2.shape[-2]
@@ -266,6 +268,7 @@ def test_matmul(x1, x2):
266268
else:
267269
res = _array_module.linalg.matmul(x1, x2)
268270

271+
# TODO: Replace result_type() with a helper function
269272
assert res.dtype == _array_module.result_type(x1, x2), "matmul() did not return the correct dtype"
270273

271274
if len(x1.shape) == len(x2.shape) == 1:

0 commit comments

Comments
 (0)