Skip to content

Commit 33c4f5e

Browse files
Update tests for dpnp.linalg.svd
1 parent 61d1ed7 commit 33c4f5e

File tree

2 files changed

+51
-41
lines changed

2 files changed

+51
-41
lines changed

tests/test_linalg.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -582,6 +582,7 @@ def test_svd(self, dtype, shape):
582582
assert dpnp_u.dtype == np_u.dtype
583583
assert dpnp_s.dtype == np_s.dtype
584584
assert dpnp_vt.dtype == np_vt.dtype
585+
585586
assert dpnp_u.shape == np_u.shape
586587
assert dpnp_s.shape == np_s.shape
587588
assert dpnp_vt.shape == np_vt.shape

tests/test_sycl_queue.py

Lines changed: 50 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1115,53 +1115,62 @@ def test_qr(device):
11151115
valid_devices,
11161116
ids=[device.filter_string for device in valid_devices],
11171117
)
1118-
def test_svd(device):
1119-
shape = (2, 2)
1118+
@pytest.mark.parametrize("full_matrices", [True, False], ids=["True", "False"])
1119+
@pytest.mark.parametrize("compute_uv", [True, False], ids=["True", "False"])
1120+
@pytest.mark.parametrize(
1121+
"shape",
1122+
[
1123+
(1, 4),
1124+
(3, 2),
1125+
(4, 4),
1126+
(2, 0),
1127+
(0, 2),
1128+
(2, 2, 3),
1129+
(3, 3, 0),
1130+
(0, 2, 3),
1131+
(1, 0, 3),
1132+
],
1133+
ids=[
1134+
"(1, 4)",
1135+
"(3, 2)",
1136+
"(4, 4)",
1137+
"(2, 0)",
1138+
"(0, 2)",
1139+
"(2, 2, 3)",
1140+
"(3, 3, 0)",
1141+
"(0, 2, 3)",
1142+
"(1, 0, 3)",
1143+
],
1144+
)
1145+
def test_svd(shape, full_matrices, compute_uv, device):
11201146
dtype = dpnp.default_float_type(device)
1121-
numpy_data = numpy.arange(shape[0] * shape[1], dtype=dtype).reshape(shape)
1122-
dpnp_data = dpnp.arange(
1123-
shape[0] * shape[1], dtype=dtype, device=device
1124-
).reshape(shape)
1125-
1126-
np_u, np_s, np_vt = numpy.linalg.svd(numpy_data)
1127-
dpnp_u, dpnp_s, dpnp_vt = dpnp.linalg.svd(dpnp_data)
1128-
1129-
assert dpnp_u.dtype == np_u.dtype
1130-
assert dpnp_s.dtype == np_s.dtype
1131-
assert dpnp_vt.dtype == np_vt.dtype
1132-
assert dpnp_u.shape == np_u.shape
1133-
assert dpnp_s.shape == np_s.shape
1134-
assert dpnp_vt.shape == np_vt.shape
1135-
1136-
# check decomposition
1137-
dpnp_diag_s = dpnp.zeros(shape, dtype=dpnp_s.dtype, device=device)
1138-
for i in range(dpnp_s.size):
1139-
dpnp_diag_s[i, i] = dpnp_s[i]
1140-
1141-
# check decomposition
1142-
assert_dtype_allclose(
1143-
dpnp_data, dpnp.dot(dpnp_u, dpnp.dot(dpnp_diag_s, dpnp_vt))
1147+
1148+
count_elems = numpy.prod(shape)
1149+
dpnp_data = dpnp.arange(count_elems, dtype=dtype, device=device).reshape(
1150+
shape
11441151
)
1152+
expected_queue = dpnp_data.get_array().sycl_queue
11451153

1146-
for i in range(min(shape[0], shape[1])):
1147-
if np_u[0, i] * dpnp_u[0, i] < 0:
1148-
np_u[:, i] = -np_u[:, i]
1149-
np_vt[i, :] = -np_vt[i, :]
1154+
if compute_uv:
1155+
dpnp_u, dpnp_s, dpnp_vt = dpnp.linalg.svd(
1156+
dpnp_data, full_matrices=full_matrices, compute_uv=compute_uv
1157+
)
11501158

1151-
# compare vectors for non-zero values
1152-
for i in range(numpy.count_nonzero(np_s)):
1153-
assert_dtype_allclose(dpnp_u[:, i], np_u[:, i])
1154-
assert_dtype_allclose(dpnp_vt[i, :], np_vt[i, :])
1159+
dpnp_u_queue = dpnp_u.get_array().sycl_queue
1160+
dpnp_vt_queue = dpnp_vt.get_array().sycl_queue
1161+
dpnp_s_queue = dpnp_s.get_array().sycl_queue
11551162

1156-
expected_queue = dpnp_data.get_array().sycl_queue
1157-
dpnp_u_queue = dpnp_u.get_array().sycl_queue
1158-
dpnp_s_queue = dpnp_s.get_array().sycl_queue
1159-
dpnp_vt_queue = dpnp_vt.get_array().sycl_queue
1163+
assert_sycl_queue_equal(dpnp_u_queue, expected_queue)
1164+
assert_sycl_queue_equal(dpnp_vt_queue, expected_queue)
1165+
assert_sycl_queue_equal(dpnp_s_queue, expected_queue)
11601166

1161-
# compare queue and device
1162-
assert_sycl_queue_equal(dpnp_u_queue, expected_queue)
1163-
assert_sycl_queue_equal(dpnp_s_queue, expected_queue)
1164-
assert_sycl_queue_equal(dpnp_vt_queue, expected_queue)
1167+
else:
1168+
dpnp_s = dpnp.linalg.svd(
1169+
dpnp_data, full_matrices=full_matrices, compute_uv=compute_uv
1170+
)
1171+
dpnp_s_queue = dpnp_s.get_array().sycl_queue
1172+
1173+
assert_sycl_queue_equal(dpnp_s_queue, expected_queue)
11651174

11661175

11671176
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)