Skip to content

Commit f6f660c

Browse files
Add dpnp.linalg.lstsq() implementation (#1792)
* Add dpnp.linalg.lstsq impl via svd call * Add cupy tests * Add copy of s to compute follows data * Add tests in test_sycl_queue and test_usm_type * Add related_arrays param to dpnp_svd/svd_batch to follow compute follows data * Move above and add docstings for _nrm2_last_axis * Add dpnp tests * Fix codespell check * Get uv_type given related_arrays in dpnp_svd * Check rcond type * Use empty_like in dpnp_svd * rcond as int, float or None * Unlock fix_random() for TestRandint2 --------- Co-authored-by: Anton <[email protected]>
1 parent 8a378b4 commit f6f660c

File tree

7 files changed

+506
-55
lines changed

7 files changed

+506
-55
lines changed

dpnp/linalg/dpnp_iface_linalg.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,13 +42,15 @@
4242
import dpnp
4343

4444
from .dpnp_utils_linalg import (
45+
check_2d,
4546
check_stacked_2d,
4647
check_stacked_square,
4748
dpnp_cholesky,
4849
dpnp_cond,
4950
dpnp_det,
5051
dpnp_eigh,
5152
dpnp_inv,
53+
dpnp_lstsq,
5254
dpnp_matrix_power,
5355
dpnp_matrix_rank,
5456
dpnp_multi_dot,
@@ -69,6 +71,7 @@
6971
"eigvals",
7072
"eigvalsh",
7173
"inv",
74+
"lstsq",
7275
"matrix_power",
7376
"matrix_rank",
7477
"multi_dot",
@@ -594,6 +597,79 @@ def inv(a):
594597
return dpnp_inv(a)
595598

596599

600+
def lstsq(a, b, rcond=None):
601+
"""
602+
Return the least-squares solution to a linear matrix equation.
603+
604+
For full documentation refer to :obj:`numpy.linalg.lstsq`.
605+
606+
Parameters
607+
----------
608+
a : (M, N) {dpnp.ndarray, usm_ndarray}
609+
"Coefficient" matrix.
610+
b : {(M,), (M, K)} {dpnp.ndarray, usm_ndarray}
611+
Ordinate or "dependent variable" values.
612+
If `b` is two-dimensional, the least-squares solution
613+
is calculated for each of the `K` columns of `b`.
614+
rcond : {int, float, None}, optional
615+
Cut-off ratio for small singular values of `a`.
616+
For the purposes of rank determination, singular values are treated
617+
as zero if they are smaller than `rcond` times the largest singular
618+
value of `a`.
619+
The default uses the machine precision times ``max(M, N)``. Passing
620+
``-1`` will use machine precision.
621+
622+
Returns
623+
-------
624+
x : {(N,), (N, K)} dpnp.ndarray
625+
Least-squares solution. If `b` is two-dimensional,
626+
the solutions are in the `K` columns of `x`.
627+
residuals : {(1,), (K,), (0,)} dpnp.ndarray
628+
Sums of squared residuals: Squared Euclidean 2-norm for each column in
629+
``b - a @ x``.
630+
If the rank of `a` is < N or M <= N, this is an empty array.
631+
If `b` is 1-dimensional, this is a (1,) shape array.
632+
Otherwise the shape is (K,).
633+
rank : int
634+
Rank of matrix `a`.
635+
s : (min(M, N),) dpnp.ndarray
636+
Singular values of `a`.
637+
638+
Examples
639+
--------
640+
Fit a line, ``y = mx + c``, through some noisy data-points:
641+
642+
>>> import dpnp as np
643+
>>> x = np.array([0, 1, 2, 3])
644+
>>> y = np.array([-1, 0.2, 0.9, 2.1])
645+
646+
By examining the coefficients, we see that the line should have a
647+
gradient of roughly 1 and cut the y-axis at, more or less, -1.
648+
649+
We can rewrite the line equation as ``y = Ap``, where ``A = [[x 1]]``
650+
and ``p = [[m], [c]]``. Now use `lstsq` to solve for `p`:
651+
652+
>>> A = np.vstack([x, np.ones(len(x))]).T
653+
>>> A
654+
array([[0., 1.],
655+
[1., 1.],
656+
[2., 1.],
657+
[3., 1.]])
658+
659+
>>> m, c = np.linalg.lstsq(A, y, rcond=None)[0]
660+
>>> m, c
661+
(array(1.), array(-0.95)) # may vary
662+
663+
"""
664+
665+
dpnp.check_supported_arrays_type(a, b)
666+
check_2d(a)
667+
if rcond is not None and not isinstance(rcond, (int, float)):
668+
raise TypeError("rcond must be integer, floating type, or None")
669+
670+
return dpnp_lstsq(a, b, rcond=rcond)
671+
672+
597673
def matrix_power(a, n):
598674
"""
599675
Raise a square matrix to the (integer) power `n`.

0 commit comments

Comments
 (0)