Skip to content

Commit d4edeef

Browse files
Merge 0293a3f into 0457fe1
2 parents 0457fe1 + 0293a3f commit d4edeef

File tree

6 files changed

+329
-1
lines changed

6 files changed

+329
-1
lines changed

dpnp/linalg/dpnp_iface_linalg.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@
5151
dpnp_det,
5252
dpnp_eigh,
5353
dpnp_inv,
54+
dpnp_pinv,
5455
dpnp_qr,
5556
dpnp_slogdet,
5657
dpnp_solve,
@@ -69,6 +70,7 @@
6970
"matrix_rank",
7071
"multi_dot",
7172
"norm",
73+
"pinv",
7274
"qr",
7375
"solve",
7476
"svd",
@@ -474,6 +476,55 @@ def multi_dot(arrays, out=None):
474476
return result
475477

476478

479+
def pinv(a, rcond=1e-15, hermitian=False):
480+
"""
481+
Compute the (Moore-Penrose) pseudo-inverse of a matrix.
482+
483+
Calculate the generalized inverse of a matrix using its
484+
singular-value decomposition (SVD) and including all large singular values.
485+
486+
For full documentation refer to :obj:`numpy.linalg.inv`.
487+
488+
Parameters
489+
----------
490+
a : (..., M, N) {dpnp.ndarray, usm_ndarray}
491+
Matrix or stack of matrices to be pseudo-inverted.
492+
rcond : float or dpnp.ndarray of float, optional
493+
Cutoff for small singular values.
494+
Singular values less than or equal to ``rcond * largest_singular_value``
495+
are set to zero.
496+
Default: ``1e-15``.
497+
hermitian : bool, optional
498+
If ``True``, a is assumed to be Hermitian (symmetric if real-valued),
499+
enabling a more efficient method for finding singular values.
500+
Default: ``False``.
501+
502+
Returns
503+
-------
504+
out : (..., N, M) dpnp.ndarray
505+
The pseudo-inverse of a.
506+
507+
Examples
508+
--------
509+
The following example checks that ``a * a+ * a == a`` and
510+
``a+ * a * a+ == a+``:
511+
512+
>>> import dpnp as np
513+
>>> a = np.random.randn(9, 6)
514+
>>> B = np.linalg.pinv(a)
515+
>>> np.allclose(a, np.dot(a, np.dot(B, a)))
516+
array([ True])
517+
>>> np.allclose(B, np.dot(B, np.dot(a, B)))
518+
array([ True])
519+
520+
"""
521+
522+
dpnp.check_supported_arrays_type(a)
523+
check_stacked_2d(a)
524+
525+
return dpnp_pinv(a, rcond=rcond, hermitian=hermitian)
526+
527+
477528
def norm(x1, ord=None, axis=None, keepdims=False):
478529
"""
479530
Matrix or vector norm.

dpnp/linalg/dpnp_utils_linalg.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
"dpnp_det",
4040
"dpnp_eigh",
4141
"dpnp_inv",
42+
"dpnp_pinv",
4243
"dpnp_qr",
4344
"dpnp_slogdet",
4445
"dpnp_solve",
@@ -998,6 +999,36 @@ def dpnp_inv(a):
998999
return b_f
9991000

10001001

1002+
def dpnp_pinv(a, rcond=1e-15, hermitian=False):
1003+
"""
1004+
dpnp_pinv(a, rcond=1e-15, hermitian=False):
1005+
1006+
Compute the Moore-Penrose pseudoinverse of `a` matrix.
1007+
1008+
It computes a pseudoinverse of a matrix `a`, which is a generalization
1009+
of the inverse matrix with Singular Value Decomposition (SVD).
1010+
1011+
"""
1012+
1013+
rcond = dpnp.array(rcond, device=a.sycl_device, sycl_queue=a.sycl_queue)
1014+
if a.size == 0:
1015+
res_type = _common_type(a)
1016+
m, n = a.shape[-2:]
1017+
if m == 0 or n == 0:
1018+
res_type = a.dtype
1019+
return dpnp.empty_like(a, shape=(a.shape[:-2] + (n, m)), dtype=res_type)
1020+
1021+
u, s, vt = dpnp_svd(a.conj(), full_matrices=False, hermitian=hermitian)
1022+
1023+
# discard small singular values
1024+
cutoff = rcond * dpnp.amax(s, axis=-1)
1025+
leq = s <= cutoff[..., None]
1026+
dpnp.reciprocal(s, out=s)
1027+
s[leq] = 0
1028+
1029+
return dpnp.matmul(vt.swapaxes(-2, -1), s[..., None] * u.swapaxes(-2, -1))
1030+
1031+
10011032
def dpnp_qr_batch(a, mode="reduced"):
10021033
"""
10031034
dpnp_qr_batch(a, mode="reduced")

tests/test_linalg.py

Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1167,3 +1167,132 @@ def test_svd_errors(self):
11671167
# a.ndim < 2
11681168
a_dp_ndim_1 = a_dp.flatten()
11691169
assert_raises(inp.linalg.LinAlgError, inp.linalg.svd, a_dp_ndim_1)
1170+
1171+
1172+
class TestPinv:
1173+
def get_tol(self, dtype):
1174+
tol = 1e-06
1175+
if dtype in (inp.float32, inp.complex64):
1176+
tol = 1e-04
1177+
elif not has_support_aspect64() and dtype in (
1178+
inp.int32,
1179+
inp.int64,
1180+
None,
1181+
):
1182+
tol = 1e-05
1183+
self._tol = tol
1184+
1185+
def check_types_shapes(self, dp_B, np_B):
1186+
if has_support_aspect64():
1187+
assert dp_B.dtype == np_B.dtype
1188+
else:
1189+
assert dp_B.dtype.kind == np_B.dtype.kind
1190+
1191+
assert dp_B.shape == np_B.shape
1192+
1193+
@pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True))
1194+
@pytest.mark.parametrize(
1195+
"shape",
1196+
[(2, 2), (3, 4), (5, 3), (16, 16), (2, 2, 2), (2, 4, 2), (2, 2, 4)],
1197+
ids=[
1198+
"(2, 2)",
1199+
"(3, 4)",
1200+
"(5, 3)",
1201+
"(16, 16)",
1202+
"(2, 2, 2)",
1203+
"(2, 4, 2)",
1204+
"(2, 2, 4)",
1205+
],
1206+
)
1207+
def test_pinv(self, dtype, shape):
1208+
a = numpy.random.rand(*shape).astype(dtype)
1209+
a_dp = inp.array(a)
1210+
1211+
B = numpy.linalg.pinv(a)
1212+
B_dp = inp.linalg.pinv(a_dp)
1213+
1214+
self.check_types_shapes(B_dp, B)
1215+
self.get_tol(dtype)
1216+
tol = self._tol
1217+
assert_allclose(B_dp, B, rtol=tol, atol=tol)
1218+
1219+
if a.ndim == 2:
1220+
reconstructed = inp.dot(a_dp, inp.dot(B_dp, a_dp))
1221+
else: # a.ndim > 2
1222+
reconstructed = inp.matmul(a_dp, inp.matmul(B_dp, a_dp))
1223+
1224+
assert_allclose(reconstructed, a, rtol=tol, atol=tol)
1225+
1226+
@pytest.mark.parametrize("dtype", get_complex_dtypes())
1227+
@pytest.mark.parametrize(
1228+
"shape",
1229+
[(2, 2), (16, 16)],
1230+
ids=["(2,2)", "(16, 16)"],
1231+
)
1232+
def test_pinv_hermitian(self, dtype, shape):
1233+
a = numpy.random.randn(*shape) + 1j * numpy.random.randn(*shape)
1234+
a = numpy.conj(a.T) @ a
1235+
1236+
a = a.astype(dtype)
1237+
a_dp = inp.array(a)
1238+
1239+
B = numpy.linalg.pinv(a)
1240+
B_dp = inp.linalg.pinv(a_dp)
1241+
1242+
self.check_types_shapes(B_dp, B)
1243+
self.get_tol(dtype)
1244+
tol = self._tol
1245+
assert_allclose(B_dp, B, rtol=tol, atol=tol)
1246+
1247+
reconstructed = inp.dot(a_dp, inp.dot(B_dp, a_dp))
1248+
assert_allclose(reconstructed, a, rtol=tol, atol=1e-03)
1249+
1250+
@pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True))
1251+
@pytest.mark.parametrize(
1252+
"shape",
1253+
[(0, 0), (0, 2), (2, 0), (2, 0, 3), (2, 3, 0), (0, 2, 3)],
1254+
ids=[
1255+
"(0, 0)",
1256+
"(0, 2)",
1257+
"(2 ,0)",
1258+
"(2, 0, 3)",
1259+
"(2, 3, 0)",
1260+
"(0, 2, 3)",
1261+
],
1262+
)
1263+
def test_pinv_empty(self, dtype, shape):
1264+
a = numpy.empty(shape, dtype=dtype)
1265+
a_dp = inp.array(a)
1266+
1267+
B = numpy.linalg.pinv(a)
1268+
B_dp = inp.linalg.pinv(a_dp)
1269+
1270+
assert_dtype_allclose(B_dp, B)
1271+
1272+
def test_pinv_strides(self):
1273+
a = numpy.random.rand(5, 5)
1274+
a_dp = inp.array(a)
1275+
1276+
self.get_tol(a_dp.dtype)
1277+
tol = self._tol
1278+
1279+
# positive strides
1280+
B = numpy.linalg.pinv(a[::2, ::2])
1281+
B_dp = inp.linalg.pinv(a_dp[::2, ::2])
1282+
assert_allclose(B_dp, B, rtol=tol, atol=tol)
1283+
1284+
# negative strides
1285+
B = numpy.linalg.pinv(a[::-2, ::-2])
1286+
B_dp = inp.linalg.pinv(a_dp[::-2, ::-2])
1287+
assert_allclose(B_dp, B, rtol=tol, atol=tol)
1288+
1289+
def test_pinv_errors(self):
1290+
a_dp = inp.array([[1, 2], [3, 4]], dtype="float32")
1291+
1292+
# unsupported type
1293+
a_np = inp.asnumpy(a_dp)
1294+
assert_raises(TypeError, inp.linalg.pinv, a_np)
1295+
1296+
# a.ndim < 2
1297+
a_dp_ndim_1 = a_dp.flatten()
1298+
assert_raises(inp.linalg.LinAlgError, inp.linalg.pinv, a_dp_ndim_1)

tests/test_sycl_queue.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1665,3 +1665,45 @@ def test_slogdet(shape, is_empty, device):
16651665

16661666
assert_sycl_queue_equal(sign_queue, dpnp_x.sycl_queue)
16671667
assert_sycl_queue_equal(logdet_queue, dpnp_x.sycl_queue)
1668+
1669+
1670+
@pytest.mark.parametrize(
1671+
"shape, hermitian",
1672+
[
1673+
((4, 4), False),
1674+
((2, 0), False),
1675+
((4, 4), True),
1676+
((2, 2, 3), False),
1677+
((0, 2, 3), False),
1678+
((1, 0, 3), False),
1679+
],
1680+
ids=[
1681+
"(4, 4)",
1682+
"(2, 0)",
1683+
"(2, 2), hermitian)",
1684+
"(2, 2, 3)",
1685+
"(0, 2, 3)",
1686+
"(1, 0, 3)",
1687+
],
1688+
)
1689+
@pytest.mark.parametrize(
1690+
"device",
1691+
valid_devices,
1692+
ids=[device.filter_string for device in valid_devices],
1693+
)
1694+
def test_pinv(shape, hermitian, device):
1695+
if hermitian:
1696+
a_np = numpy.random.randn(*shape) + 1j * numpy.random.randn(*shape)
1697+
a_np = numpy.conj(a_np.T) @ a_np
1698+
else:
1699+
a_np = numpy.random.randn(*shape)
1700+
1701+
a_dp = dpnp.array(a_np, device=device)
1702+
1703+
B_result = dpnp.linalg.pinv(a_dp, hermitian=hermitian)
1704+
B_expected = numpy.linalg.pinv(a_np, hermitian=hermitian)
1705+
assert_allclose(B_expected, B_result, rtol=1e-3, atol=1e-4)
1706+
1707+
B_queue = B_result.sycl_queue
1708+
1709+
assert_sycl_queue_equal(B_queue, a_dp.sycl_queue)

tests/test_usm_type.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -829,6 +829,39 @@ def test_svd(usm_type, shape, full_matrices_param, compute_uv_param):
829829
assert x.usm_type == s.usm_type
830830

831831

832+
@pytest.mark.parametrize("usm_type", list_of_usm_types, ids=list_of_usm_types)
833+
@pytest.mark.parametrize(
834+
"shape, hermitian",
835+
[
836+
((4, 4), False),
837+
((2, 0), False),
838+
((4, 4), True),
839+
((2, 2, 3), False),
840+
((0, 2, 3), False),
841+
((1, 0, 3), False),
842+
],
843+
ids=[
844+
"(4, 4)",
845+
"(2, 0)",
846+
"(2, 2), hermitian)",
847+
"(2, 2, 3)",
848+
"(0, 2, 3)",
849+
"(1, 0, 3)",
850+
],
851+
)
852+
def test_pinv(shape, hermitian, usm_type):
853+
if hermitian:
854+
a = dp.random.randn(*shape) + 1j * dp.random.randn(*shape)
855+
a = dp.conj(a.T) @ a
856+
else:
857+
a = dp.random.randn(*shape)
858+
859+
a = dp.array(a, usm_type=usm_type)
860+
B = dp.lialg.pinv(a, hermitian=hermitian)
861+
862+
assert a.usm_type == B.usm_type
863+
864+
832865
@pytest.mark.parametrize("usm_type", list_of_usm_types, ids=list_of_usm_types)
833866
@pytest.mark.parametrize(
834867
"shape",
@@ -852,7 +885,7 @@ def test_svd(usm_type, shape, full_matrices_param, compute_uv_param):
852885
["r", "raw", "complete", "reduced"],
853886
ids=["r", "raw", "complete", "reduced"],
854887
)
855-
def test_qr(shape, mode, usm_type):
888+
def test_pinv(shape, mode, usm_type):
856889
count_elems = numpy.prod(shape)
857890
a = dp.arange(count_elems, usm_type=usm_type).reshape(shape)
858891

tests/third_party/cupy/linalg_tests/test_solve.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,3 +166,45 @@ def test_batched_inv(self, dtype):
166166
assert a.ndim >= 3 # CuPy internally uses a batched function.
167167
with pytest.raises(xp.linalg.LinAlgError):
168168
xp.linalg.inv(a)
169+
170+
171+
class TestPinv(unittest.TestCase):
172+
@testing.for_dtypes("ifdFD")
173+
@_condition.retry(10)
174+
def check_x(self, a_shape, rcond, dtype):
175+
a_gpu = testing.shaped_random(a_shape, dtype=dtype)
176+
a_cpu = cupy.asnumpy(a_gpu)
177+
a_gpu_copy = a_gpu.copy()
178+
if not isinstance(rcond, float):
179+
rcond = numpy.asarray(rcond)
180+
result_cpu = numpy.linalg.pinv(a_cpu, rcond=rcond)
181+
if not isinstance(rcond, float):
182+
rcond = cupy.asarray(rcond)
183+
result_gpu = cupy.linalg.pinv(a_gpu, rcond=rcond)
184+
185+
assert_dtype_allclose(result_gpu, result_cpu)
186+
testing.assert_array_equal(a_gpu_copy, a_gpu)
187+
188+
def test_pinv(self):
189+
self.check_x((3, 3), rcond=1e-15)
190+
self.check_x((2, 4), rcond=1e-15)
191+
self.check_x((3, 2), rcond=1e-15)
192+
193+
self.check_x((4, 4), rcond=0.3)
194+
self.check_x((2, 5), rcond=0.5)
195+
self.check_x((5, 3), rcond=0.6)
196+
197+
def test_pinv_batched(self):
198+
self.check_x((2, 3, 4), rcond=1e-15)
199+
self.check_x((2, 3, 4, 5), rcond=1e-15)
200+
201+
def test_pinv_batched_vector_rcond(self):
202+
self.check_x((2, 3, 4), rcond=[0.2, 0.8])
203+
self.check_x((2, 3, 4, 5), rcond=[[0.2, 0.9, 0.1], [0.7, 0.2, 0.5]])
204+
205+
def test_pinv_size_0(self):
206+
self.check_x((3, 0), rcond=1e-15)
207+
self.check_x((0, 3), rcond=1e-15)
208+
self.check_x((0, 0), rcond=1e-15)
209+
self.check_x((0, 2, 3), rcond=1e-15)
210+
self.check_x((2, 0, 3), rcond=1e-15)

0 commit comments

Comments
 (0)