@@ -443,9 +443,7 @@ def dpnp_solve(a, b):
443
443
return b_f
444
444
445
445
446
- def _dpnp_svd_batch (
447
- a , res_type , res_type_s , full_matrices = True , compute_uv = True
448
- ):
446
+ def _dpnp_svd_batch (a , uv_type , s_type , full_matrices = True , compute_uv = True ):
449
447
a_usm_type = a .usm_type
450
448
a_sycl_queue = a .sycl_queue
451
449
reshape = False
@@ -464,34 +462,34 @@ def _dpnp_svd_batch(
464
462
k = min (m , n )
465
463
s = dpnp .empty (
466
464
batch_shape_orig + (k ,),
467
- dtype = res_type_s ,
465
+ dtype = s_type ,
468
466
usm_type = a_usm_type ,
469
467
sycl_queue = a_sycl_queue ,
470
468
)
471
469
if compute_uv :
472
470
if full_matrices :
473
471
u = dpnp .empty (
474
472
batch_shape_orig + (n , n ),
475
- dtype = res_type ,
473
+ dtype = uv_type ,
476
474
usm_type = a_usm_type ,
477
475
sycl_queue = a_sycl_queue ,
478
476
)
479
477
vt = dpnp .empty (
480
478
batch_shape_orig + (m , m ),
481
- dtype = res_type ,
479
+ dtype = uv_type ,
482
480
usm_type = a_usm_type ,
483
481
sycl_queue = a_sycl_queue ,
484
482
)
485
483
else :
486
484
u = dpnp .empty (
487
485
batch_shape_orig + (n , k ),
488
- dtype = res_type ,
486
+ dtype = uv_type ,
489
487
usm_type = a_usm_type ,
490
488
sycl_queue = a_sycl_queue ,
491
489
)
492
490
vt = dpnp .empty (
493
491
batch_shape_orig + (k , m ),
494
- dtype = res_type ,
492
+ dtype = uv_type ,
495
493
usm_type = a_usm_type ,
496
494
sycl_queue = a_sycl_queue ,
497
495
)
@@ -501,7 +499,7 @@ def _dpnp_svd_batch(
501
499
elif m == 0 or n == 0 :
502
500
s = dpnp .empty (
503
501
batch_shape_orig + (0 ,),
504
- dtype = res_type_s ,
502
+ dtype = s_type ,
505
503
usm_type = a_usm_type ,
506
504
sycl_queue = a_sycl_queue ,
507
505
)
@@ -510,27 +508,27 @@ def _dpnp_svd_batch(
510
508
u = _stacked_identity (
511
509
batch_shape_orig ,
512
510
n ,
513
- dtype = res_type ,
511
+ dtype = uv_type ,
514
512
usm_type = a_usm_type ,
515
513
sycl_queue = a_sycl_queue ,
516
514
)
517
515
vt = _stacked_identity (
518
516
batch_shape_orig ,
519
517
m ,
520
- dtype = res_type ,
518
+ dtype = uv_type ,
521
519
usm_type = a_usm_type ,
522
520
sycl_queue = a_sycl_queue ,
523
521
)
524
522
else :
525
523
u = dpnp .empty (
526
524
batch_shape_orig + (n , 0 ),
527
- dtype = res_type ,
525
+ dtype = uv_type ,
528
526
usm_type = a_usm_type ,
529
527
sycl_queue = a_sycl_queue ,
530
528
)
531
529
vt = dpnp .empty (
532
530
batch_shape_orig + (0 , m ),
533
- dtype = res_type ,
531
+ dtype = uv_type ,
534
532
usm_type = a_usm_type ,
535
533
sycl_queue = a_sycl_queue ,
536
534
)
@@ -579,72 +577,52 @@ def dpnp_svd(a, full_matrices=True, compute_uv=True):
579
577
a_usm_type = a .usm_type
580
578
a_sycl_queue = a .sycl_queue
581
579
582
- # TODO: Use linalg_common_type from #1598
583
- if dpnp .issubdtype (a .dtype , dpnp .floating ):
584
- res_type = (
585
- a .dtype
586
- if a_sycl_queue .sycl_device .has_aspect_fp64
587
- else dpnp .float32
588
- )
589
- elif dpnp .issubdtype (a .dtype , dpnp .complexfloating ):
590
- res_type = (
591
- a .dtype
592
- if a_sycl_queue .sycl_device .has_aspect_fp64
593
- else dpnp .complex64
594
- )
595
- else :
596
- res_type = (
597
- dpnp .float64
598
- if a_sycl_queue .sycl_device .has_aspect_fp64
599
- else dpnp .float32
600
- )
580
+ uv_type = _common_type (a )
601
581
602
- res_type_s = (
582
+ s_type = (
603
583
dpnp .float64
604
584
if a_sycl_queue .sycl_device .has_aspect_fp64
605
- and (res_type == dpnp .float64 or res_type == dpnp .complex128 )
585
+ and (uv_type == dpnp .float64 or uv_type == dpnp .complex128 )
606
586
else dpnp .float32
607
587
)
608
588
609
589
if a .ndim > 2 :
610
- return _dpnp_svd_batch (
611
- a , res_type , res_type_s , full_matrices , compute_uv
612
- )
590
+ return _dpnp_svd_batch (a , uv_type , s_type , full_matrices , compute_uv )
613
591
614
592
else :
615
593
n , m = a .shape
616
594
617
595
if m == 0 or n == 0 :
618
596
s = dpnp .empty (
619
597
(0 ,),
620
- dtype = res_type_s ,
598
+ dtype = s_type ,
621
599
usm_type = a_usm_type ,
622
600
sycl_queue = a_sycl_queue ,
623
601
)
624
602
if compute_uv :
625
603
if full_matrices :
626
604
u = dpnp .eye (
627
605
n ,
628
- dtype = res_type ,
606
+ dtype = uv_type ,
629
607
usm_type = a_usm_type ,
630
608
sycl_queue = a_sycl_queue ,
631
609
)
632
610
vt = dpnp .eye (
633
611
m ,
634
- dtype = res_type ,
612
+ dtype = uv_type ,
635
613
usm_type = a_usm_type ,
636
614
sycl_queue = a_sycl_queue ,
637
615
)
638
616
else :
639
617
u = dpnp .empty (
640
618
(n , 0 ),
641
- dtype = res_type ,
619
+ dtype = uv_type ,
642
620
usm_type = a_usm_type ,
643
621
sycl_queue = a_sycl_queue ,
644
622
)
645
623
vt = dpnp .empty (
646
624
(0 , m ),
647
- dtype = res_type ,
625
+ dtype = uv_type ,
648
626
usm_type = a_usm_type ,
649
627
sycl_queue = a_sycl_queue ,
650
628
)
@@ -656,12 +634,12 @@ def dpnp_svd(a, full_matrices=True, compute_uv=True):
656
634
# `a` must be traspotted if m < n
657
635
if m >= n :
658
636
x = a
659
- a_h = dpnp .empty_like (a , order = "C" , dtype = res_type )
637
+ a_h = dpnp .empty_like (a , order = "C" , dtype = uv_type )
660
638
trans_flag = False
661
639
else :
662
640
m , n = a .shape
663
641
x = a .transpose ()
664
- a_h = dpnp .empty_like (x , order = "C" , dtype = res_type )
642
+ a_h = dpnp .empty_like (x , order = "C" , dtype = uv_type )
665
643
trans_flag = True
666
644
667
645
a_usm_arr = dpnp .get_usm_ndarray (x )
@@ -677,23 +655,23 @@ def dpnp_svd(a, full_matrices=True, compute_uv=True):
677
655
if full_matrices :
678
656
u_h = dpnp .empty (
679
657
(m , m ),
680
- dtype = res_type ,
658
+ dtype = uv_type ,
681
659
usm_type = a_usm_type ,
682
660
sycl_queue = a_sycl_queue ,
683
661
)
684
662
vt_h = dpnp .empty (
685
663
(n , n ),
686
- dtype = res_type ,
664
+ dtype = uv_type ,
687
665
usm_type = a_usm_type ,
688
666
sycl_queue = a_sycl_queue ,
689
667
)
690
668
jobu = ord ("A" )
691
669
jobvt = ord ("A" )
692
670
else :
693
- u_h = dpnp .empty_like (x , dtype = res_type )
671
+ u_h = dpnp .empty_like (x , dtype = uv_type )
694
672
vt_h = dpnp .empty (
695
673
(k , n ),
696
- dtype = res_type ,
674
+ dtype = uv_type ,
697
675
usm_type = a_usm_type ,
698
676
sycl_queue = a_sycl_queue ,
699
677
)
@@ -702,21 +680,21 @@ def dpnp_svd(a, full_matrices=True, compute_uv=True):
702
680
else :
703
681
u_h = dpnp .empty (
704
682
[],
705
- dtype = res_type ,
683
+ dtype = uv_type ,
706
684
usm_type = a_usm_type ,
707
685
sycl_queue = a_sycl_queue ,
708
686
)
709
687
vt_h = dpnp .empty (
710
688
[],
711
- dtype = res_type ,
689
+ dtype = uv_type ,
712
690
usm_type = a_usm_type ,
713
691
sycl_queue = a_sycl_queue ,
714
692
)
715
693
jobu = ord ("N" )
716
694
jobvt = ord ("N" )
717
695
718
696
s_h = dpnp .empty (
719
- k , dtype = res_type_s , usm_type = a_usm_type , sycl_queue = a_sycl_queue
697
+ k , dtype = s_type , usm_type = a_usm_type , sycl_queue = a_sycl_queue
720
698
)
721
699
722
700
ht_lapack_ev , _ = li ._gesvd (
0 commit comments