Skip to content

Commit 61d1ed7

Browse files
Simplify dpnp_svd_batch
1 parent 61257e2 commit 61d1ed7

File tree

1 file changed

+20
-28
lines changed

1 file changed

+20
-28
lines changed

dpnp/linalg/dpnp_utils_linalg.py

Lines changed: 20 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -479,7 +479,7 @@ def dpnp_solve(a, b):
479479
return b_f
480480

481481

482-
def _dpnp_svd_batch(a, uv_type, s_type, full_matrices=True, compute_uv=True):
482+
def dpnp_svd_batch(a, uv_type, s_type, full_matrices=True, compute_uv=True):
483483
a_usm_type = a.usm_type
484484
a_sycl_queue = a.sycl_queue
485485
reshape = False
@@ -490,8 +490,7 @@ def _dpnp_svd_batch(a, uv_type, s_type, full_matrices=True, compute_uv=True):
490490
a = a.reshape(prod(a.shape[:-2]), a.shape[-2], a.shape[-1])
491491
reshape = True
492492

493-
batch_shape = a.shape[:-2]
494-
batch_size = prod(batch_shape)
493+
batch_size = a.shape[0]
495494
n, m = a.shape[-2:]
496495

497496
if batch_size == 0:
@@ -504,31 +503,24 @@ def _dpnp_svd_batch(a, uv_type, s_type, full_matrices=True, compute_uv=True):
504503
)
505504
if compute_uv:
506505
if full_matrices:
507-
u = dpnp.empty(
508-
batch_shape_orig + (n, n),
509-
dtype=uv_type,
510-
usm_type=a_usm_type,
511-
sycl_queue=a_sycl_queue,
512-
)
513-
vt = dpnp.empty(
514-
batch_shape_orig + (m, m),
515-
dtype=uv_type,
516-
usm_type=a_usm_type,
517-
sycl_queue=a_sycl_queue,
518-
)
506+
u_shape = batch_shape_orig + (n, n)
507+
vt_shape = batch_shape_orig + (m, m)
519508
else:
520-
u = dpnp.empty(
521-
batch_shape_orig + (n, k),
522-
dtype=uv_type,
523-
usm_type=a_usm_type,
524-
sycl_queue=a_sycl_queue,
525-
)
526-
vt = dpnp.empty(
527-
batch_shape_orig + (k, m),
528-
dtype=uv_type,
529-
usm_type=a_usm_type,
530-
sycl_queue=a_sycl_queue,
531-
)
509+
u_shape = batch_shape_orig + (n, k)
510+
vt_shape = batch_shape_orig + (k, m)
511+
512+
u = dpnp.empty(
513+
u_shape,
514+
dtype=uv_type,
515+
usm_type=a_usm_type,
516+
sycl_queue=a_sycl_queue,
517+
)
518+
vt = dpnp.empty(
519+
vt_shape,
520+
dtype=uv_type,
521+
usm_type=a_usm_type,
522+
sycl_queue=a_sycl_queue,
523+
)
532524
return u, s, vt
533525
else:
534526
return s
@@ -617,7 +609,7 @@ def dpnp_svd(a, full_matrices=True, compute_uv=True):
617609
s_type = uv_type.char.lower()
618610

619611
if a.ndim > 2:
620-
return _dpnp_svd_batch(a, uv_type, s_type, full_matrices, compute_uv)
612+
return dpnp_svd_batch(a, uv_type, s_type, full_matrices, compute_uv)
621613

622614
else:
623615
n, m = a.shape

0 commit comments

Comments
 (0)