Skip to content

Commit e43b9a3

Browse files
Remove all TODOs
1 parent c08959d commit e43b9a3

File tree

4 files changed

+36
-68
lines changed

4 files changed

+36
-68
lines changed

dpnp/linalg/dpnp_iface_linalg.py

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -642,18 +642,10 @@ def svd(a, full_matrices=True, compute_uv=True, hermitian=False):
642642
643643
"""
644644

645-
if not dpnp.is_supported_array_type(a):
646-
raise TypeError(
647-
"An array must be any of supported type, but got {}".format(type(a))
648-
)
645+
dpnp.check_supported_arrays_type(a)
646+
check_stacked_2d(a)
649647

650648
if hermitian is True:
651649
raise ValueError("The hermitian argument is only supported as False")
652650

653-
# TODO: use dpnp.linalg.LinAlgError
654-
if a.ndim < 2:
655-
raise ValueError(
656-
f"{a.ndim}-dimensional array given. Array must be at least two-dimensional"
657-
)
658-
659651
return dpnp_svd(a, full_matrices, compute_uv)

dpnp/linalg/dpnp_utils_linalg.py

Lines changed: 29 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -443,9 +443,7 @@ def dpnp_solve(a, b):
443443
return b_f
444444

445445

446-
def _dpnp_svd_batch(
447-
a, res_type, res_type_s, full_matrices=True, compute_uv=True
448-
):
446+
def _dpnp_svd_batch(a, uv_type, s_type, full_matrices=True, compute_uv=True):
449447
a_usm_type = a.usm_type
450448
a_sycl_queue = a.sycl_queue
451449
reshape = False
@@ -464,34 +462,34 @@ def _dpnp_svd_batch(
464462
k = min(m, n)
465463
s = dpnp.empty(
466464
batch_shape_orig + (k,),
467-
dtype=res_type_s,
465+
dtype=s_type,
468466
usm_type=a_usm_type,
469467
sycl_queue=a_sycl_queue,
470468
)
471469
if compute_uv:
472470
if full_matrices:
473471
u = dpnp.empty(
474472
batch_shape_orig + (n, n),
475-
dtype=res_type,
473+
dtype=uv_type,
476474
usm_type=a_usm_type,
477475
sycl_queue=a_sycl_queue,
478476
)
479477
vt = dpnp.empty(
480478
batch_shape_orig + (m, m),
481-
dtype=res_type,
479+
dtype=uv_type,
482480
usm_type=a_usm_type,
483481
sycl_queue=a_sycl_queue,
484482
)
485483
else:
486484
u = dpnp.empty(
487485
batch_shape_orig + (n, k),
488-
dtype=res_type,
486+
dtype=uv_type,
489487
usm_type=a_usm_type,
490488
sycl_queue=a_sycl_queue,
491489
)
492490
vt = dpnp.empty(
493491
batch_shape_orig + (k, m),
494-
dtype=res_type,
492+
dtype=uv_type,
495493
usm_type=a_usm_type,
496494
sycl_queue=a_sycl_queue,
497495
)
@@ -501,7 +499,7 @@ def _dpnp_svd_batch(
501499
elif m == 0 or n == 0:
502500
s = dpnp.empty(
503501
batch_shape_orig + (0,),
504-
dtype=res_type_s,
502+
dtype=s_type,
505503
usm_type=a_usm_type,
506504
sycl_queue=a_sycl_queue,
507505
)
@@ -510,27 +508,27 @@ def _dpnp_svd_batch(
510508
u = _stacked_identity(
511509
batch_shape_orig,
512510
n,
513-
dtype=res_type,
511+
dtype=uv_type,
514512
usm_type=a_usm_type,
515513
sycl_queue=a_sycl_queue,
516514
)
517515
vt = _stacked_identity(
518516
batch_shape_orig,
519517
m,
520-
dtype=res_type,
518+
dtype=uv_type,
521519
usm_type=a_usm_type,
522520
sycl_queue=a_sycl_queue,
523521
)
524522
else:
525523
u = dpnp.empty(
526524
batch_shape_orig + (n, 0),
527-
dtype=res_type,
525+
dtype=uv_type,
528526
usm_type=a_usm_type,
529527
sycl_queue=a_sycl_queue,
530528
)
531529
vt = dpnp.empty(
532530
batch_shape_orig + (0, m),
533-
dtype=res_type,
531+
dtype=uv_type,
534532
usm_type=a_usm_type,
535533
sycl_queue=a_sycl_queue,
536534
)
@@ -579,72 +577,52 @@ def dpnp_svd(a, full_matrices=True, compute_uv=True):
579577
a_usm_type = a.usm_type
580578
a_sycl_queue = a.sycl_queue
581579

582-
# TODO: Use linalg_common_type from #1598
583-
if dpnp.issubdtype(a.dtype, dpnp.floating):
584-
res_type = (
585-
a.dtype
586-
if a_sycl_queue.sycl_device.has_aspect_fp64
587-
else dpnp.float32
588-
)
589-
elif dpnp.issubdtype(a.dtype, dpnp.complexfloating):
590-
res_type = (
591-
a.dtype
592-
if a_sycl_queue.sycl_device.has_aspect_fp64
593-
else dpnp.complex64
594-
)
595-
else:
596-
res_type = (
597-
dpnp.float64
598-
if a_sycl_queue.sycl_device.has_aspect_fp64
599-
else dpnp.float32
600-
)
580+
uv_type = _common_type(a)
601581

602-
res_type_s = (
582+
s_type = (
603583
dpnp.float64
604584
if a_sycl_queue.sycl_device.has_aspect_fp64
605-
and (res_type == dpnp.float64 or res_type == dpnp.complex128)
585+
and (uv_type == dpnp.float64 or uv_type == dpnp.complex128)
606586
else dpnp.float32
607587
)
608588

609589
if a.ndim > 2:
610-
return _dpnp_svd_batch(
611-
a, res_type, res_type_s, full_matrices, compute_uv
612-
)
590+
return _dpnp_svd_batch(a, uv_type, s_type, full_matrices, compute_uv)
613591

614592
else:
615593
n, m = a.shape
616594

617595
if m == 0 or n == 0:
618596
s = dpnp.empty(
619597
(0,),
620-
dtype=res_type_s,
598+
dtype=s_type,
621599
usm_type=a_usm_type,
622600
sycl_queue=a_sycl_queue,
623601
)
624602
if compute_uv:
625603
if full_matrices:
626604
u = dpnp.eye(
627605
n,
628-
dtype=res_type,
606+
dtype=uv_type,
629607
usm_type=a_usm_type,
630608
sycl_queue=a_sycl_queue,
631609
)
632610
vt = dpnp.eye(
633611
m,
634-
dtype=res_type,
612+
dtype=uv_type,
635613
usm_type=a_usm_type,
636614
sycl_queue=a_sycl_queue,
637615
)
638616
else:
639617
u = dpnp.empty(
640618
(n, 0),
641-
dtype=res_type,
619+
dtype=uv_type,
642620
usm_type=a_usm_type,
643621
sycl_queue=a_sycl_queue,
644622
)
645623
vt = dpnp.empty(
646624
(0, m),
647-
dtype=res_type,
625+
dtype=uv_type,
648626
usm_type=a_usm_type,
649627
sycl_queue=a_sycl_queue,
650628
)
@@ -656,12 +634,12 @@ def dpnp_svd(a, full_matrices=True, compute_uv=True):
656634
# `a` must be traspotted if m < n
657635
if m >= n:
658636
x = a
659-
a_h = dpnp.empty_like(a, order="C", dtype=res_type)
637+
a_h = dpnp.empty_like(a, order="C", dtype=uv_type)
660638
trans_flag = False
661639
else:
662640
m, n = a.shape
663641
x = a.transpose()
664-
a_h = dpnp.empty_like(x, order="C", dtype=res_type)
642+
a_h = dpnp.empty_like(x, order="C", dtype=uv_type)
665643
trans_flag = True
666644

667645
a_usm_arr = dpnp.get_usm_ndarray(x)
@@ -677,23 +655,23 @@ def dpnp_svd(a, full_matrices=True, compute_uv=True):
677655
if full_matrices:
678656
u_h = dpnp.empty(
679657
(m, m),
680-
dtype=res_type,
658+
dtype=uv_type,
681659
usm_type=a_usm_type,
682660
sycl_queue=a_sycl_queue,
683661
)
684662
vt_h = dpnp.empty(
685663
(n, n),
686-
dtype=res_type,
664+
dtype=uv_type,
687665
usm_type=a_usm_type,
688666
sycl_queue=a_sycl_queue,
689667
)
690668
jobu = ord("A")
691669
jobvt = ord("A")
692670
else:
693-
u_h = dpnp.empty_like(x, dtype=res_type)
671+
u_h = dpnp.empty_like(x, dtype=uv_type)
694672
vt_h = dpnp.empty(
695673
(k, n),
696-
dtype=res_type,
674+
dtype=uv_type,
697675
usm_type=a_usm_type,
698676
sycl_queue=a_sycl_queue,
699677
)
@@ -702,21 +680,21 @@ def dpnp_svd(a, full_matrices=True, compute_uv=True):
702680
else:
703681
u_h = dpnp.empty(
704682
[],
705-
dtype=res_type,
683+
dtype=uv_type,
706684
usm_type=a_usm_type,
707685
sycl_queue=a_sycl_queue,
708686
)
709687
vt_h = dpnp.empty(
710688
[],
711-
dtype=res_type,
689+
dtype=uv_type,
712690
usm_type=a_usm_type,
713691
sycl_queue=a_sycl_queue,
714692
)
715693
jobu = ord("N")
716694
jobvt = ord("N")
717695

718696
s_h = dpnp.empty(
719-
k, dtype=res_type_s, usm_type=a_usm_type, sycl_queue=a_sycl_queue
697+
k, dtype=s_type, usm_type=a_usm_type, sycl_queue=a_sycl_queue
720698
)
721699

722700
ht_lapack_ev, _ = li._gesvd(

tests/test_linalg.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -647,6 +647,5 @@ def test_svd_errors(self):
647647
assert_raises(ValueError, inp.linalg.svd, a_dp, hermitian=True)
648648

649649
# a.ndim < 2
650-
# TODO: use inp.linalg.LinAlgError
651650
a_dp_ndim_1 = a_dp.flatten()
652-
assert_raises(ValueError, inp.linalg.svd, a_dp_ndim_1)
651+
assert_raises(inp.linalg.LinAlgError, inp.linalg.svd, a_dp_ndim_1)

tests/third_party/cupy/linalg_tests/test_decomposition.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
import pytest
55

66
import dpnp as cupy
7-
from dpnp import random
87
from tests.helper import has_support_aspect64
98
from tests.third_party.cupy import testing
109
from tests.third_party.cupy.testing import condition
@@ -155,7 +154,7 @@ def test_svd_rank2_empty_array_compute_uv_false(self, xp):
155154
array, full_matrices=self.full_matrices, compute_uv=False
156155
)
157156

158-
# # @condition.repeat(3, 10)
157+
# @condition.repeat(3, 10)
159158
def test_svd_rank3(self):
160159
self.check_usv((2, 3, 4))
161160
self.check_usv((2, 3, 7))
@@ -220,22 +219,22 @@ def test_svd_rank4(self):
220219
self.check_usv((2, 2, 4, 3))
221220
self.check_usv((2, 2, 32, 32))
222221

223-
# # @condition.repeat(3, 10)
222+
# @condition.repeat(3, 10)
224223
def test_svd_rank4_loop(self):
225224
# This tests the loop-based batched gesvd on CUDA (_gesvd_batched)
226225
self.check_usv((3, 2, 64, 64))
227226
self.check_usv((3, 2, 64, 32))
228227
self.check_usv((3, 2, 32, 64))
229228

230-
# # @condition.repeat(3, 10)
229+
# @condition.repeat(3, 10)
231230
def test_svd_rank4_no_uv(self):
232231
self.check_singular((2, 2, 3, 4))
233232
self.check_singular((2, 2, 3, 7))
234233
self.check_singular((2, 2, 4, 4))
235234
self.check_singular((2, 2, 7, 3))
236235
self.check_singular((2, 2, 4, 3))
237236

238-
# # @condition.repeat(3, 10)
237+
# @condition.repeat(3, 10)
239238
def test_svd_rank4_no_uv_loop(self):
240239
# This tests the loop-based batched gesvd on CUDA (_gesvd_batched)
241240
self.check_singular((3, 2, 64, 64))

0 commit comments

Comments
 (0)