Skip to content

Commit 71277b0

Browse files
Update test_qr in test_sycl_queue
1 parent 7e9750a commit 71277b0

File tree

2 files changed

+99
-30
lines changed

2 files changed

+99
-30
lines changed

dpnp/linalg/dpnp_utils_linalg.py

Lines changed: 60 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -806,18 +806,57 @@ def dpnp_qr_batch(a, mode="reduced"):
806806
if batch_size == 0 or k == 0:
807807
if mode == "reduced":
808808
return (
809-
dpnp.empty(batch_shape + (m, k), dtype=res_type),
810-
dpnp.empty(batch_shape + (k, n), dtype=res_type),
809+
dpnp.empty(
810+
batch_shape + (m, k),
811+
dtype=res_type,
812+
sycl_queue=a_sycl_queue,
813+
usm_type=a_usm_type,
814+
),
815+
dpnp.empty(
816+
batch_shape + (k, n),
817+
dtype=res_type,
818+
sycl_queue=a_sycl_queue,
819+
usm_type=a_usm_type,
820+
),
811821
)
812822
elif mode == "complete":
813-
q = _stacked_identity(batch_shape, m, dtype=res_type)
814-
return (q, dpnp.empty(batch_shape + (m, n), dtype=res_type))
823+
q = _stacked_identity(
824+
batch_shape,
825+
m,
826+
dtype=res_type,
827+
usm_type=a_usm_type,
828+
sycl_queue=a_sycl_queue,
829+
)
830+
return (
831+
q,
832+
dpnp.empty(
833+
batch_shape + (m, n),
834+
dtype=res_type,
835+
sycl_queue=a_sycl_queue,
836+
usm_type=a_usm_type,
837+
),
838+
)
815839
elif mode == "r":
816-
return dpnp.empty(batch_shape + (k, n), dtype=res_type)
840+
return dpnp.empty(
841+
batch_shape + (k, n),
842+
dtype=res_type,
843+
sycl_queue=a_sycl_queue,
844+
usm_type=a_usm_type,
845+
)
817846
elif mode == "raw":
818847
return (
819-
dpnp.empty(batch_shape + (n, m), dtype=res_type),
820-
dpnp.empty(batch_shape + (k,), dtype=res_type),
848+
dpnp.empty(
849+
batch_shape + (n, m),
850+
dtype=res_type,
851+
sycl_queue=a_sycl_queue,
852+
usm_type=a_usm_type,
853+
),
854+
dpnp.empty(
855+
batch_shape + (k,),
856+
dtype=res_type,
857+
sycl_queue=a_sycl_queue,
858+
usm_type=a_usm_type,
859+
),
821860
)
822861

823862
# get 3d input arrays by reshape
@@ -826,7 +865,7 @@ def dpnp_qr_batch(a, mode="reduced"):
826865
a = a.swapaxes(-2, -1)
827866
a_usm_arr = dpnp.get_usm_ndarray(a)
828867

829-
a_t = dpnp.empty_like(a, order="C", dtype=res_type, usm_type=a_usm_type)
868+
a_t = dpnp.empty_like(a, order="C", dtype=res_type)
830869

831870
# use DPCTL tensor function to fill the matrix array
832871
# with content from the input array `a`
@@ -876,10 +915,20 @@ def dpnp_qr_batch(a, mode="reduced"):
876915

877916
if mode == "complete" and m > n:
878917
mc = m
879-
q = dpnp.empty((batch_size, m, m), dtype=res_type)
918+
q = dpnp.empty(
919+
(batch_size, m, m),
920+
dtype=res_type,
921+
sycl_queue=a_sycl_queue,
922+
usm_type=a_usm_type,
923+
)
880924
else:
881925
mc = k
882-
q = dpnp.empty((batch_size, n, m), dtype=res_type)
926+
q = dpnp.empty(
927+
(batch_size, n, m),
928+
dtype=res_type,
929+
sycl_queue=a_sycl_queue,
930+
usm_type=a_usm_type,
931+
)
883932
q[..., :n, :] = a_t
884933

885934
q_stride = q.strides[0]
@@ -977,7 +1026,7 @@ def dpnp_qr(a, mode="reduced"):
9771026

9781027
a = a.T
9791028
a_usm_arr = dpnp.get_usm_ndarray(a)
980-
a_t = dpnp.empty_like(a, order="C", dtype=res_type, usm_type=a_usm_type)
1029+
a_t = dpnp.empty_like(a, order="C", dtype=res_type)
9811030

9821031
# use DPCTL tensor function to fill the matrix array
9831032
# with content from the input array `a`

tests/test_sycl_queue.py

Lines changed: 39 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1171,34 +1171,54 @@ def test_matrix_rank(device):
11711171
assert_array_equal(expected, result)
11721172

11731173

1174+
@pytest.mark.parametrize(
1175+
"shape",
1176+
[
1177+
(4, 4),
1178+
(2, 0),
1179+
(2, 2, 3),
1180+
(0, 2, 3),
1181+
(1, 0, 3),
1182+
],
1183+
ids=[
1184+
"(4, 4)",
1185+
"(2, 0)",
1186+
"(2, 2, 3)",
1187+
"(0, 2, 3)",
1188+
"(1, 0, 3)",
1189+
],
1190+
)
1191+
@pytest.mark.parametrize(
1192+
"mode",
1193+
["r", "raw", "complete", "reduced"],
1194+
ids=["r", "raw", "complete", "reduced"],
1195+
)
11741196
@pytest.mark.parametrize(
11751197
"device",
11761198
valid_devices,
11771199
ids=[device.filter_string for device in valid_devices],
11781200
)
1179-
def test_qr(device):
1180-
data = [[1.0, 2.0, 3.0], [1.0, 2.0, 3.0]]
1181-
dpnp_data = dpnp.array(data, device=device)
1182-
numpy_data = numpy.array(data, dtype=dpnp_data.dtype)
1183-
1184-
np_q, np_r = numpy.linalg.qr(numpy_data, "reduced")
1185-
dpnp_q, dpnp_r = dpnp.linalg.qr(dpnp_data, "reduced")
1201+
def test_qr(shape, mode, device):
1202+
dtype = dpnp.default_float_type(device)
1203+
count_elems = numpy.prod(shape)
1204+
dpnp_data = dpnp.arange(count_elems, dtype=dtype, device=device).reshape(
1205+
shape
1206+
)
11861207

1187-
assert dpnp_q.dtype == np_q.dtype
1188-
assert dpnp_r.dtype == np_r.dtype
1189-
assert dpnp_q.shape == np_q.shape
1190-
assert dpnp_r.shape == np_r.shape
1208+
expected_queue = dpnp_data.get_array().sycl_queue
11911209

1192-
assert_dtype_allclose(dpnp_q, np_q)
1193-
assert_dtype_allclose(dpnp_r, np_r)
1210+
if mode == "r":
1211+
dpnp_r = dpnp.linalg.qr(dpnp_data, mode=mode)
1212+
dpnp_r_queue = dpnp_r.get_array().sycl_queue
1213+
assert_sycl_queue_equal(dpnp_r_queue, expected_queue)
1214+
else:
1215+
dpnp_q, dpnp_r = dpnp.linalg.qr(dpnp_data, mode=mode)
11941216

1195-
expected_queue = dpnp_data.get_array().sycl_queue
1196-
dpnp_q_queue = dpnp_q.get_array().sycl_queue
1197-
dpnp_r_queue = dpnp_r.get_array().sycl_queue
1217+
dpnp_q_queue = dpnp_q.get_array().sycl_queue
1218+
dpnp_r_queue = dpnp_r.get_array().sycl_queue
11981219

1199-
# compare queue and device
1200-
assert_sycl_queue_equal(dpnp_q_queue, expected_queue)
1201-
assert_sycl_queue_equal(dpnp_r_queue, expected_queue)
1220+
assert_sycl_queue_equal(dpnp_q_queue, expected_queue)
1221+
assert_sycl_queue_equal(dpnp_r_queue, expected_queue)
12021222

12031223

12041224
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)