@@ -479,7 +479,7 @@ def dpnp_solve(a, b):
479
479
return b_f
480
480
481
481
482
- def _dpnp_svd_batch (a , uv_type , s_type , full_matrices = True , compute_uv = True ):
482
+ def dpnp_svd_batch (a , uv_type , s_type , full_matrices = True , compute_uv = True ):
483
483
a_usm_type = a .usm_type
484
484
a_sycl_queue = a .sycl_queue
485
485
reshape = False
@@ -490,8 +490,7 @@ def _dpnp_svd_batch(a, uv_type, s_type, full_matrices=True, compute_uv=True):
490
490
a = a .reshape (prod (a .shape [:- 2 ]), a .shape [- 2 ], a .shape [- 1 ])
491
491
reshape = True
492
492
493
- batch_shape = a .shape [:- 2 ]
494
- batch_size = prod (batch_shape )
493
+ batch_size = a .shape [0 ]
495
494
n , m = a .shape [- 2 :]
496
495
497
496
if batch_size == 0 :
@@ -504,31 +503,24 @@ def _dpnp_svd_batch(a, uv_type, s_type, full_matrices=True, compute_uv=True):
504
503
)
505
504
if compute_uv :
506
505
if full_matrices :
507
- u = dpnp .empty (
508
- batch_shape_orig + (n , n ),
509
- dtype = uv_type ,
510
- usm_type = a_usm_type ,
511
- sycl_queue = a_sycl_queue ,
512
- )
513
- vt = dpnp .empty (
514
- batch_shape_orig + (m , m ),
515
- dtype = uv_type ,
516
- usm_type = a_usm_type ,
517
- sycl_queue = a_sycl_queue ,
518
- )
506
+ u_shape = batch_shape_orig + (n , n )
507
+ vt_shape = batch_shape_orig + (m , m )
519
508
else :
520
- u = dpnp .empty (
521
- batch_shape_orig + (n , k ),
522
- dtype = uv_type ,
523
- usm_type = a_usm_type ,
524
- sycl_queue = a_sycl_queue ,
525
- )
526
- vt = dpnp .empty (
527
- batch_shape_orig + (k , m ),
528
- dtype = uv_type ,
529
- usm_type = a_usm_type ,
530
- sycl_queue = a_sycl_queue ,
531
- )
509
+ u_shape = batch_shape_orig + (n , k )
510
+ vt_shape = batch_shape_orig + (k , m )
511
+
512
+ u = dpnp .empty (
513
+ u_shape ,
514
+ dtype = uv_type ,
515
+ usm_type = a_usm_type ,
516
+ sycl_queue = a_sycl_queue ,
517
+ )
518
+ vt = dpnp .empty (
519
+ vt_shape ,
520
+ dtype = uv_type ,
521
+ usm_type = a_usm_type ,
522
+ sycl_queue = a_sycl_queue ,
523
+ )
532
524
return u , s , vt
533
525
else :
534
526
return s
@@ -617,7 +609,7 @@ def dpnp_svd(a, full_matrices=True, compute_uv=True):
617
609
s_type = uv_type .char .lower ()
618
610
619
611
if a .ndim > 2 :
620
- return _dpnp_svd_batch (a , uv_type , s_type , full_matrices , compute_uv )
612
+ return dpnp_svd_batch (a , uv_type , s_type , full_matrices , compute_uv )
621
613
622
614
else :
623
615
n , m = a .shape
0 commit comments