Skip to content

Commit d81b917

Browse files
Added tensor_impl._same_logical_tensors predicate
The predicate determines is argument arrays are the same (same dimension, shape, data type, pointer, strides). Used to determine if copying must be performed in case of overlap to avoid race condition.
1 parent bee2698 commit d81b917

File tree

1 file changed

+10
-0
lines changed

1 file changed

+10
-0
lines changed

dpctl/tensor/libtensor/source/tensor_py.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ using dpctl::tensor::c_contiguous_strides;
6060
using dpctl::tensor::f_contiguous_strides;
6161

6262
using dpctl::tensor::overlap::MemoryOverlap;
63+
using dpctl::tensor::overlap::SameLogicalTensors;
6364

6465
using dpctl::tensor::py_internal::copy_usm_ndarray_into_usm_ndarray;
6566

@@ -338,6 +339,15 @@ PYBIND11_MODULE(_tensor_impl, m)
338339
"Determines if the memory regions indexed by each array overlap",
339340
py::arg("array1"), py::arg("array2"));
340341

342+
auto same_logical_tensors = [](dpctl::tensor::usm_ndarray x1,
343+
dpctl::tensor::usm_ndarray x2) -> bool {
344+
auto const &same_logical_tensors = SameLogicalTensors();
345+
return same_logical_tensors(x1, x2);
346+
};
347+
m.def("_same_logical_tensors", same_logical_tensors,
348+
"Determines if the memory regions indexed by each array are the same",
349+
py::arg("array1"), py::arg("array2"));
350+
341351
m.def("_place", &py_place, "", py::arg("dst"), py::arg("cumsum"),
342352
py::arg("axis_start"), py::arg("axis_end"), py::arg("rhs"),
343353
py::arg("sycl_queue"), py::arg("depends") = py::list());

0 commit comments

Comments
 (0)