Skip to content

Commit 21957a9

Browse files
Add queues_are_compatible signare for list of usm_ndarray instances
Instead of call ```c++ queues_are_compatible(exec_q, {X1.get_queue(), X2.get_queue()}) ``` for `usm_ndarray` instances `X1` and `X2`, one can now say ```c++ queues_are_compatible(exec_q, {X1, X2}) ```
1 parent da56dce commit 21957a9

File tree

1 file changed

+17
-0
lines changed

1 file changed

+17
-0
lines changed

dpctl/apis/include/dpctl4pybind11.hpp

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -987,6 +987,8 @@ sycl::event keep_args_alive(sycl::queue q,
987987
return host_task_ev;
988988
}
989989

990+
/*! @brief Check if all allocation queues are the same as the
991+
execution queue */
990992
template <std::size_t num>
991993
bool queues_are_compatible(sycl::queue exec_q,
992994
const sycl::queue (&alloc_qs)[num])
@@ -1000,6 +1002,21 @@ bool queues_are_compatible(sycl::queue exec_q,
10001002
return true;
10011003
}
10021004

1005+
/*! @brief Check if all allocation queues of usm_ndarays are the same as
1006+
the execution queue */
1007+
template <std::size_t num>
1008+
bool queues_are_compatible(sycl::queue exec_q,
1009+
const ::dpctl::tensor::usm_ndarray (&arrs)[num])
1010+
{
1011+
for (std::size_t i = 0; i < num; ++i) {
1012+
1013+
if (exec_q != arrs[i].get_queue()) {
1014+
return false;
1015+
}
1016+
}
1017+
return true;
1018+
}
1019+
10031020
} // end namespace utils
10041021

10051022
} // end namespace dpctl

0 commit comments

Comments
 (0)