Skip to content

Commit 23b9e04

Browse files
Adjusted tests to work on GPUs w/o double precision support
1 parent caded96 commit 23b9e04

File tree

2 files changed

+5
-7
lines changed

2 files changed

+5
-7
lines changed

examples/pybind11/onemkl_gemv/sycl_gemm/_onemkl.cpp

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -62,13 +62,11 @@ py_gemv(sycl::queue q,
6262
throw std::runtime_error("Inconsistent shapes.");
6363
}
6464

65-
auto q_ctx = q.get_context();
66-
if (q_ctx != matrix.get_queue().get_context() ||
67-
q_ctx != vector.get_queue().get_context() ||
68-
q_ctx != result.get_queue().get_context())
65+
if (!dpctl::utils::queues_are_compatible(
66+
q, {matrix.get_queue(), vector.get_queue(), result.get_queue()}))
6967
{
7068
throw std::runtime_error(
71-
"USM allocation is not bound to the context in execution queue.");
69+
"USM allocations are not compatible with the execution queue.");
7270
}
7371

7472
auto &api = dpctl::detail::dpctl_capi::get();

examples/pybind11/onemkl_gemv/tests/test_gemm.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,9 @@ def test_gemv():
1818
except dpctl.SyclQueueCreationError:
1919
pytest.skip("Queue could not be created")
2020
Mnp, vnp = np.random.randn(5, 3), np.random.randn(3)
21-
r = dpt.empty((5,), dtype="d", sycl_queue=q)
2221
M = dpt.asarray(Mnp, sycl_queue=q)
2322
v = dpt.asarray(vnp, sycl_queue=q)
23+
r = dpt.empty((5,), dtype=v.dtype, sycl_queue=q)
2424
hev, ev = gemv(q, M, v, r, [])
2525
hev.wait()
2626
rnp = dpt.asnumpy(r)
@@ -33,9 +33,9 @@ def test_sub():
3333
except dpctl.SyclQueueCreationError:
3434
pytest.skip("Queue could not be created")
3535
anp, bnp = np.random.randn(5), np.random.randn(5)
36-
r = dpt.empty((5,), dtype="d", sycl_queue=q)
3736
a = dpt.asarray(anp, sycl_queue=q)
3837
b = dpt.asarray(bnp, sycl_queue=q)
38+
r = dpt.empty((5,), dtype=b.dtype, sycl_queue=q)
3939
hev, ev = sub(q, a, b, r, [])
4040
hev.wait()
4141
rnp = dpt.asnumpy(r)

0 commit comments

Comments
 (0)