|
17 | 17 | if TYPE_CHECKING:
|
18 | 18 | from ...common._typing import Array
|
19 | 19 |
|
20 |
| -# cupy.linalg doesn't have __all__. If it is added, replace this with |
| 20 | +# dask.array.linalg doesn't have __all__. If it is added, replace this with |
21 | 21 | #
|
22 |
| -# from cupy.linalg import __all__ as linalg_all |
| 22 | +# from dask.array.linalg import __all__ as linalg_all |
23 | 23 | _n = {}
|
24 | 24 | exec('from dask.array.linalg import *', _n)
|
25 | 25 | del _n['__builtins__']
|
|
32 | 32 | QRResult = _linalg.QRResult
|
33 | 33 | SlogdetResult = _linalg.SlogdetResult
|
34 | 34 | SVDResult = _linalg.SVDResult
|
35 |
| -qr = get_xp(da)(_linalg.qr) |
| 35 | +# TODO: use the QR wrapper once dask |
| 36 | +# supports the mode keyword on QR |
| 37 | +# https://github.com/dask/dask/issues/10388 |
| 38 | +#qr = get_xp(da)(_linalg.qr) |
| 39 | +def qr(x: Array, mode: Literal['reduced', 'complete'] = 'reduced', |
| 40 | + **kwargs) -> QRResult: |
| 41 | + if mode != "reduced": |
| 42 | + raise ValueError("dask arrays only support using mode='reduced'") |
| 43 | + return QRResult(*da.linalg.qr(x, **kwargs)) |
36 | 44 | cholesky = get_xp(da)(_linalg.cholesky)
|
37 | 45 | matrix_rank = get_xp(da)(_linalg.matrix_rank)
|
38 | 46 | matrix_norm = get_xp(da)(_linalg.matrix_norm)
|
|
44 | 52 | def svd(x: Array, full_matrices: bool = True, **kwargs) -> SVDResult:
|
45 | 53 | if full_matrices:
|
46 | 54 | raise ValueError("full_matrics=True is not supported by dask.")
|
47 |
| - return da.linalg.svd(x, **kwargs) |
| 55 | + return da.linalg.svd(x, coerce_signs=False, **kwargs) |
48 | 56 |
|
49 | 57 | def svdvals(x: Array) -> Array:
|
50 | 58 | # TODO: can't avoid computing U or V for dask
|
|
0 commit comments