Skip to content

Commit e44469c

Browse files
Update dpnp.linalg.matrix_power() implementation (#1748)
* Add an implementation of dpnp.linalg.matrix_power * Update cupy tests for matrix_power * Add dpnp tests for matrix_power * Use no_bool in tests to avoid singular input matrix * Address remarks * Improve performance for _stacked_identity functions * Add TestMatrixPowerBatched to cupy tests * Update dpnp tests for matrix_power * Efficient use of binary decomposition --------- Co-authored-by: Anton <[email protected]>
1 parent 1b58244 commit e44469c

File tree

8 files changed

+236
-29
lines changed

8 files changed

+236
-29
lines changed

dpnp/linalg/dpnp_iface_linalg.py

Lines changed: 47 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@
5151
dpnp_det,
5252
dpnp_eigh,
5353
dpnp_inv,
54+
dpnp_matrix_power,
5455
dpnp_matrix_rank,
5556
dpnp_multi_dot,
5657
dpnp_pinv,
@@ -370,33 +371,65 @@ def inv(a):
370371
return dpnp_inv(a)
371372

372373

373-
def matrix_power(input, count):
374+
def matrix_power(a, n):
    """
    Raise a square matrix to the (integer) power `n`.

    For full documentation refer to :obj:`numpy.linalg.matrix_power`.

    Parameters
    ----------
    a : (..., M, M) {dpnp.ndarray, usm_ndarray}
        Matrix to be "powered".
    n : int
        The exponent can be any integer, positive, negative, or zero.
        Integer-like objects implementing ``__index__`` (for example
        NumPy integer scalars) are accepted and converted to a Python int.

    Returns
    -------
    a**n : (..., M, M) dpnp.ndarray
        The return value is the same shape and type as `a`;
        if the exponent is positive or zero then the type of the
        elements is the same as those of `a`. If the exponent is
        negative the elements are floating-point.

    Raises
    ------
    TypeError
        If `a` is not a supported array type, or if `n` cannot be
        interpreted as an integer (e.g. ``1.5`` or ``[2]``).

    Notes
    -----
    The shape validation is delegated to ``check_stacked_2d`` and
    ``check_stacked_square``; a negative exponent goes through matrix
    inversion, so a singular input is expected to fail there
    (NOTE(review): exact exception types come from those helpers - confirm).

    Examples
    --------
    >>> import dpnp as np
    >>> i = np.array([[0, 1], [-1, 0]]) # matrix equiv. of the imaginary unit
    >>> np.linalg.matrix_power(i, 3) # should = -i
    array([[ 0, -1],
           [ 1,  0]])
    >>> np.linalg.matrix_power(i, 0)
    array([[1, 0],
           [0, 1]])
    >>> np.linalg.matrix_power(i, -3) # should = 1/(-i) = i, but w/ f.p. elements
    array([[ 0.,  1.],
           [-1.,  0.]])

    Somewhat more sophisticated example

    >>> q = np.zeros((4, 4))
    >>> q[0:2, 0:2] = -i
    >>> q[2:4, 2:4] = i
    >>> q # one of the three quaternion units not equal to 1
    array([[ 0., -1.,  0.,  0.],
           [ 1.,  0.,  0.,  0.],
           [ 0.,  0.,  0.,  1.],
           [ 0.,  0., -1.,  0.]])
    >>> np.linalg.matrix_power(q, 2) # = -np.eye(4)
    array([[-1.,  0.,  0.,  0.],
           [ 0., -1.,  0.,  0.],
           [ 0.,  0., -1.,  0.],
           [ 0.,  0.,  0., -1.]])

    """

    # Imported locally so this block does not depend on a file-level import.
    import operator

    dpnp.check_supported_arrays_type(a)
    check_stacked_2d(a)
    check_stacked_square(a)

    # operator.index() accepts every integer-like object (Python ints,
    # NumPy integer scalars, ...) and rejects floats and sequences, which
    # mirrors how numpy.linalg.matrix_power validates the exponent.
    try:
        n = operator.index(n)
    except TypeError as e:
        raise TypeError("exponent must be an integer") from e

    return dpnp_matrix_power(a, n)
400433

401434

402435
def matrix_rank(A, tol=None, hermitian=False):

dpnp/linalg/dpnp_utils_linalg.py

Lines changed: 85 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
"dpnp_det",
4141
"dpnp_eigh",
4242
"dpnp_inv",
43+
"dpnp_matrix_power",
4344
"dpnp_matrix_rank",
4445
"dpnp_multi_dot",
4546
"dpnp_pinv",
@@ -526,9 +527,50 @@ def _stacked_identity(
526527
"""
527528

528529
shape = batch_shape + (n, n)
529-
idx = dpnp.arange(n, usm_type=usm_type, sycl_queue=sycl_queue)
530-
x = dpnp.zeros(shape, dtype=dtype, usm_type=usm_type, sycl_queue=sycl_queue)
531-
x[..., idx, idx] = 1
530+
x = dpnp.empty(shape, dtype=dtype, usm_type=usm_type, sycl_queue=sycl_queue)
531+
x[...] = dpnp.eye(
532+
n, dtype=x.dtype, usm_type=x.usm_type, sycl_queue=x.sycl_queue
533+
)
534+
return x
535+
536+
537+
def _stacked_identity_like(x):
    """
    Build a stack of identity matrices that mirrors the array `x`.

    Parameters
    ----------
    x : dpnp.ndarray
        Template array. Its shape, data type, USM type and SYCL queue
        define the properties of the result.

    Returns
    -------
    out : dpnp.ndarray
        Newly allocated array obtained via ``dpnp.empty_like(x)`` whose
        trailing two dimensions are filled with an identity matrix of
        size ``x.shape[-2]``, broadcast over all leading dimensions.

    Example
    -------
    >>> import dpnp
    >>> x = dpnp.zeros((2, 3, 3), dtype=dpnp.int64)
    >>> out = _stacked_identity_like(x)
    >>> out.shape
    (2, 3, 3)
    >>> out.dtype == x.dtype
    True

    """

    out = dpnp.empty_like(x)

    # A single identity matrix is created with the properties of the
    # freshly allocated array and broadcast into every matrix of the stack.
    identity = dpnp.eye(
        out.shape[-2],
        dtype=out.dtype,
        usm_type=out.usm_type,
        sycl_queue=out.sycl_queue,
    )
    out[...] = identity
    return out
533575

534576

@@ -1082,6 +1124,46 @@ def dpnp_inv(a):
10821124
return b_f
10831125

10841126

1127+
def dpnp_matrix_power(a, n):
    """
    dpnp_matrix_power(a, n)

    Compute the `n`-th (integer) power of the square matrix (or stack of
    square matrices) `a`.

    A zero exponent yields stacked identity matrices, a negative exponent
    is handled by inverting `a` first and powering the inverse with ``-n``.
    Exponents 1, 2 and 3 are resolved directly; larger exponents use
    exponentiation by squaring.

    """

    if n == 0:
        return _stacked_identity_like(a)

    if n < 0:
        # a**(-k) == inv(a)**k
        a = dpnp.linalg.inv(a)
        n = -n

    # Short-cuts for the smallest exponents.
    if n == 1:
        return a
    if n in (2, 3):
        squared = dpnp.matmul(a, a)
        return squared if n == 2 else dpnp.matmul(squared, a)

    # Exponentiation by squaring: walk over the binary digits of the
    # exponent starting from the least significant one.
    # `base` holds a**(2**k) for the current digit k, `power` collects the
    # product of the bases that correspond to the set bits.
    power = None
    base = a.copy()
    remaining = n
    while remaining > 0:
        remaining, lowest_bit = divmod(remaining, 2)
        if lowest_bit:
            if power is None:
                power = base.copy()
            else:
                dpnp.matmul(power, base, out=power)
        # Square the base only when a higher bit is still pending.
        if remaining > 0:
            dpnp.matmul(base, base, out=base)

    return power
1165+
1166+
10851167
def dpnp_matrix_rank(A, tol=None, hermitian=False):
10861168
"""
10871169
dpnp_matrix_rank(A, tol=None, hermitian=False)

tests/skipped_tests.tbl

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -332,10 +332,6 @@ tests/third_party/cupy/linalg_tests/test_einsum.py::TestListArgEinSumError::test
332332
tests/third_party/cupy/linalg_tests/test_einsum.py::TestListArgEinSumError::test_invalid_sub1
333333
tests/third_party/cupy/linalg_tests/test_einsum.py::TestListArgEinSumError::test_too_many_dims3
334334

335-
tests/third_party/cupy/linalg_tests/test_product.py::TestMatrixPower::test_matrix_power_invlarge
336-
tests/third_party/cupy/linalg_tests/test_product.py::TestMatrixPower::test_matrix_power_large
337-
tests/third_party/cupy/linalg_tests/test_product.py::TestMatrixPower::test_matrix_power_of_two
338-
339335
tests/third_party/cupy/logic_tests/test_comparison.py::TestArrayEqual::test_array_equal_broadcast_not_allowed
340336
tests/third_party/cupy/logic_tests/test_comparison.py::TestArrayEqual::test_array_equal_diff_dtypes_is_equal
341337
tests/third_party/cupy/logic_tests/test_comparison.py::TestArrayEqual::test_array_equal_diff_dtypes_not_equal

tests/skipped_tests_gpu.tbl

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -434,10 +434,6 @@ tests/third_party/cupy/linalg_tests/test_einsum.py::TestEinSumUnaryOperationWith
434434
tests/third_party/cupy/linalg_tests/test_einsum.py::TestEinSumUnaryOperationWithScalar::test_scalar_int
435435
tests/third_party/cupy/linalg_tests/test_einsum.py::TestListArgEinSumError::test_invalid_sub1
436436

437-
tests/third_party/cupy/linalg_tests/test_product.py::TestMatrixPower::test_matrix_power_invlarge
438-
tests/third_party/cupy/linalg_tests/test_product.py::TestMatrixPower::test_matrix_power_large
439-
tests/third_party/cupy/linalg_tests/test_product.py::TestMatrixPower::test_matrix_power_of_two
440-
441437
tests/third_party/cupy/logic_tests/test_comparison.py::TestArrayEqual::test_array_equal_broadcast_not_allowed
442438
tests/third_party/cupy/logic_tests/test_comparison.py::TestArrayEqual::test_array_equal_diff_dtypes_is_equal
443439
tests/third_party/cupy/logic_tests/test_comparison.py::TestArrayEqual::test_array_equal_diff_dtypes_not_equal

tests/test_linalg.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -566,6 +566,56 @@ def test_inv_errors(self):
566566
assert_raises(inp.linalg.LinAlgError, inp.linalg.inv, a_dp)
567567

568568

569+
class TestMatrixPower:
    """Compare dpnp.linalg.matrix_power with numpy.linalg.matrix_power."""

    @pytest.mark.parametrize("dtype", get_all_dtypes())
    @pytest.mark.parametrize(
        "data, power",
        [
            # Block-diagonal matrix
            (
                numpy.block(
                    [
                        [numpy.eye(2), numpy.zeros((2, 2))],
                        [numpy.zeros((2, 2)), numpy.eye(2) * 2],
                    ]
                ),
                3,
            ),
            # Non-diagonal matrix
            (numpy.eye(3, k=1) + numpy.eye(3), 3),
            # Inverse of non-diagonal matrix
            (numpy.eye(3, k=1) + numpy.eye(3), -3),
        ],
    )
    def test_matrix_power(self, data, power, dtype):
        # The same host data feeds both the dpnp and the numpy call.
        np_matrix = data.astype(dtype)
        dp_matrix = inp.array(np_matrix)

        result = inp.linalg.matrix_power(dp_matrix, power)
        expected = numpy.linalg.matrix_power(np_matrix, power)

        assert_dtype_allclose(result, expected)

    def test_matrix_power_errors(self):
        dp_matrix = inp.eye(4, dtype="float32")

        # unsupported type `a`: a plain numpy array must be rejected
        np_matrix = inp.asnumpy(dp_matrix)
        assert_raises(TypeError, inp.linalg.matrix_power, np_matrix, 2)

        # unsupported type `power`
        for bad_exponent in (1.5, [2]):
            assert_raises(
                TypeError, inp.linalg.matrix_power, dp_matrix, bad_exponent
            )

        # not invertible
        # TODO: remove it when mkl>=2024.0 is released (MKLD-16626)
        if is_cpu_device():
            return

        singular = inp.array([[1, 0], [0, 0]])
        assert_raises(
            inp.linalg.LinAlgError, inp.linalg.matrix_power, singular, -1
        )
617+
618+
569619
class TestMatrixRank:
570620
@pytest.mark.parametrize("dtype", get_all_dtypes())
571621
@pytest.mark.parametrize(

tests/test_sycl_queue.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1299,6 +1299,30 @@ def test_inv(shape, is_empty, device):
12991299
assert_sycl_queue_equal(result_queue, expected_queue)
13001300

13011301

1302+
@pytest.mark.parametrize(
    "n",
    [-1, 0, 1, 2, 3],
    ids=["-1", "0", "1", "2", "3"],
)
@pytest.mark.parametrize(
    "device",
    valid_devices,
    ids=[device.filter_string for device in valid_devices],
)
def test_matrix_power(n, device):
    """matrix_power must match numpy and keep the input's SYCL queue."""
    float_dtype = dpnp.default_float_type(device)
    host_matrix = numpy.array([[1, 2], [3, 5]], dtype=float_dtype)
    device_matrix = dpnp.array(host_matrix, device=device)

    result = dpnp.linalg.matrix_power(device_matrix, n)
    expected = numpy.linalg.matrix_power(host_matrix, n)
    assert_dtype_allclose(result, expected)

    # The result has to live on the same queue as the input array.
    expected_queue = device_matrix.get_array().sycl_queue
    result_queue = result.get_array().sycl_queue

    assert_sycl_queue_equal(result_queue, expected_queue)
1324+
1325+
13021326
@pytest.mark.parametrize(
13031327
"data, tol",
13041328
[

tests/test_usm_type.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -924,6 +924,19 @@ def test_svd(usm_type, shape, full_matrices_param, compute_uv_param):
924924
assert x.usm_type == s.usm_type
925925

926926

927+
@pytest.mark.parametrize(
    "n",
    [-1, 0, 1, 2, 3],
    ids=["-1", "0", "1", "2", "3"],
)
@pytest.mark.parametrize("usm_type", list_of_usm_types, ids=list_of_usm_types)
def test_matrix_power(n, usm_type):
    """The USM type of the input must be propagated to the result."""
    matrix = dp.array([[1, 2], [3, 5]], usm_type=usm_type)

    powered = dp.linalg.matrix_power(matrix, n)
    assert matrix.usm_type == powered.usm_type
938+
939+
927940
@pytest.mark.parametrize(
928941
"data, tol",
929942
[

tests/third_party/cupy/linalg_tests/test_product.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -430,7 +430,6 @@ def test_tensordot_zero_length(self, xp, dtype):
430430

431431

432432
class TestMatrixPower(unittest.TestCase):
433-
@pytest.mark.usefixtures("allow_fall_back_on_numpy")
434433
@testing.for_all_dtypes()
435434
@testing.numpy_cupy_allclose()
436435
def test_matrix_power_0(self, xp, dtype):
@@ -455,23 +454,20 @@ def test_matrix_power_3(self, xp, dtype):
455454
a = testing.shaped_arange((3, 3), xp, dtype)
456455
return xp.linalg.matrix_power(a, 3)
457456

458-
@pytest.mark.usefixtures("allow_fall_back_on_numpy")
459457
@testing.for_float_dtypes(no_float16=True)
460458
@testing.numpy_cupy_allclose(rtol=1e-5)
461459
def test_matrix_power_inv1(self, xp, dtype):
462460
a = testing.shaped_arange((3, 3), xp, dtype)
463461
a = a * a % 30
464462
return xp.linalg.matrix_power(a, -1)
465463

466-
@pytest.mark.usefixtures("allow_fall_back_on_numpy")
467464
@testing.for_float_dtypes(no_float16=True)
468465
@testing.numpy_cupy_allclose(rtol=1e-5)
469466
def test_matrix_power_inv2(self, xp, dtype):
470467
a = testing.shaped_arange((3, 3), xp, dtype)
471468
a = a * a % 30
472469
return xp.linalg.matrix_power(a, -2)
473470

474-
@pytest.mark.usefixtures("allow_fall_back_on_numpy")
475471
@testing.for_float_dtypes(no_float16=True)
476472
@testing.numpy_cupy_allclose(rtol=1e-4)
477473
def test_matrix_power_inv3(self, xp, dtype):
@@ -496,3 +492,20 @@ def test_matrix_power_large(self, xp, dtype):
496492
def test_matrix_power_invlarge(self, xp, dtype):
497493
a = xp.eye(23, k=17, dtype=dtype) + xp.eye(23, k=-6, dtype=dtype)
498494
return xp.linalg.matrix_power(a, -987654321987654321)
495+
496+
497+
@pytest.mark.parametrize(
    "shape",
    [
        (2, 3, 3),
        (3, 0, 0),
    ],
)
@pytest.mark.parametrize("n", [0, 5, -7])
class TestMatrixPowerBatched:
    """matrix_power applied to stacks of matrices, including empty ones."""

    @testing.for_float_dtypes(no_float16=True)
    @testing.numpy_cupy_allclose(rtol=5e-5)
    def test_matrix_power_batched(self, xp, dtype, shape, n):
        stacked = testing.shaped_arange(shape, xp, dtype)
        # Add the identity to every matrix of the stack before powering.
        stacked += xp.identity(shape[-1], dtype)
        return xp.linalg.matrix_power(stacked, n)

0 commit comments

Comments
 (0)