Skip to content

Commit ec4f628

Browse files
committed
last linalg fixes
1 parent def6b46 commit ec4f628

File tree

1 file changed

+12
-4
lines changed

1 file changed

+12
-4
lines changed

array_api_compat/dask/array/linalg.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,9 @@
1717
if TYPE_CHECKING:
1818
from ...common._typing import Array
1919

20-
# cupy.linalg doesn't have __all__. If it is added, replace this with
20+
# dask.array.linalg doesn't have __all__. If it is added, replace this with
2121
#
22-
# from cupy.linalg import __all__ as linalg_all
22+
# from dask.array.linalg import __all__ as linalg_all
2323
_n = {}
2424
exec('from dask.array.linalg import *', _n)
2525
del _n['__builtins__']
@@ -32,7 +32,15 @@
3232
QRResult = _linalg.QRResult
3333
SlogdetResult = _linalg.SlogdetResult
3434
SVDResult = _linalg.SVDResult
35-
qr = get_xp(da)(_linalg.qr)
35+
# TODO: use the QR wrapper once dask
36+
# supports the mode keyword on QR
37+
# https://github.com/dask/dask/issues/10388
38+
#qr = get_xp(da)(_linalg.qr)
39+
def qr(x: Array, mode: Literal['reduced', 'complete'] = 'reduced',
40+
**kwargs) -> QRResult:
41+
if mode != "reduced":
42+
raise ValueError("dask arrays only support using mode='reduced'")
43+
return QRResult(*da.linalg.qr(x, **kwargs))
3644
cholesky = get_xp(da)(_linalg.cholesky)
3745
matrix_rank = get_xp(da)(_linalg.matrix_rank)
3846
matrix_norm = get_xp(da)(_linalg.matrix_norm)
@@ -44,7 +52,7 @@
4452
def svd(x: Array, full_matrices: bool = True, **kwargs) -> SVDResult:
4553
if full_matrices:
4654
raise ValueError("full_matrics=True is not supported by dask.")
47-
return da.linalg.svd(x, **kwargs)
55+
return da.linalg.svd(x, coerce_signs=False, **kwargs)
4856

4957
def svdvals(x: Array) -> Array:
5058
# TODO: can't avoid computing U or V for dask

0 commit comments

Comments
 (0)