Skip to content

Commit 0fe3346

Browse files
Add test_qr in test_usm_type
1 parent 71277b0 commit 0fe3346

File tree

2 files changed

+47
-12
lines changed

2 files changed

+47
-12
lines changed

tests/test_sycl_queue.py

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1201,24 +1201,22 @@ def test_matrix_rank(device):
12011201
def test_qr(shape, mode, device):
12021202
dtype = dpnp.default_float_type(device)
12031203
count_elems = numpy.prod(shape)
1204-
dpnp_data = dpnp.arange(count_elems, dtype=dtype, device=device).reshape(
1205-
shape
1206-
)
1204+
a = dpnp.arange(count_elems, dtype=dtype, device=device).reshape(shape)
12071205

1208-
expected_queue = dpnp_data.get_array().sycl_queue
1206+
expected_queue = a.get_array().sycl_queue
12091207

12101208
if mode == "r":
1211-
dpnp_r = dpnp.linalg.qr(dpnp_data, mode=mode)
1212-
dpnp_r_queue = dpnp_r.get_array().sycl_queue
1213-
assert_sycl_queue_equal(dpnp_r_queue, expected_queue)
1209+
dp_r = dpnp.linalg.qr(a, mode=mode)
1210+
dp_r_queue = dp_r.get_array().sycl_queue
1211+
assert_sycl_queue_equal(dp_r_queue, expected_queue)
12141212
else:
1215-
dpnp_q, dpnp_r = dpnp.linalg.qr(dpnp_data, mode=mode)
1213+
dp_q, dp_r = dpnp.linalg.qr(a, mode=mode)
12161214

1217-
dpnp_q_queue = dpnp_q.get_array().sycl_queue
1218-
dpnp_r_queue = dpnp_r.get_array().sycl_queue
1215+
dp_q_queue = dp_q.get_array().sycl_queue
1216+
dp_r_queue = dp_r.get_array().sycl_queue
12191217

1220-
assert_sycl_queue_equal(dpnp_q_queue, expected_queue)
1221-
assert_sycl_queue_equal(dpnp_r_queue, expected_queue)
1218+
assert_sycl_queue_equal(dp_q_queue, expected_queue)
1219+
assert_sycl_queue_equal(dp_r_queue, expected_queue)
12221220

12231221

12241222
@pytest.mark.parametrize(

tests/test_usm_type.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -710,3 +710,40 @@ def test_det(shape, is_empty, usm_type):
710710
det = dp.linalg.det(x)
711711

712712
assert x.usm_type == det.usm_type
713+
714+
715+
@pytest.mark.parametrize("usm_type", list_of_usm_types, ids=list_of_usm_types)
716+
@pytest.mark.parametrize(
717+
"shape",
718+
[
719+
(4, 4),
720+
(2, 0),
721+
(2, 2, 3),
722+
(0, 2, 3),
723+
(1, 0, 3),
724+
],
725+
ids=[
726+
"(4, 4)",
727+
"(2, 0)",
728+
"(2, 2, 3)",
729+
"(0, 2, 3)",
730+
"(1, 0, 3)",
731+
],
732+
)
733+
@pytest.mark.parametrize(
734+
"mode",
735+
["r", "raw", "complete", "reduced"],
736+
ids=["r", "raw", "complete", "reduced"],
737+
)
738+
def test_qr(shape, mode, usm_type):
739+
count_elems = numpy.prod(shape)
740+
a = dp.arange(count_elems, usm_type=usm_type).reshape(shape)
741+
742+
if mode == "r":
743+
dp_r = dp.linalg.qr(a, mode=mode)
744+
assert a.usm_type == dp_r.usm_type
745+
else:
746+
dp_q, dp_r = dp.linalg.qr(a, mode=mode)
747+
748+
assert a.usm_type == dp_q.usm_type
749+
assert a.usm_type == dp_r.usm_type

0 commit comments

Comments
 (0)