Skip to content

Commit 08bba11

Browse files
Moved definition of tril/triu into dedicated file
1 parent 785670e commit 08bba11

File tree

4 files changed

+355
-248
lines changed

4 files changed

+355
-248
lines changed

dpctl/tensor/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ pybind11_add_module(${python_module_name} MODULE
2525
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/linear_sequences.cpp
2626
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/eye_ctor.cpp
2727
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/full_ctor.cpp
28+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/triul_ctor.cpp
2829
)
2930
target_link_options(${python_module_name} PRIVATE -fsycl-device-code-split=per_kernel)
3031
target_include_directories(${python_module_name}

dpctl/tensor/libtensor/source/tensor_py.cpp

Lines changed: 15 additions & 248 deletions
Original file line numberDiff line numberDiff line change
@@ -33,33 +33,25 @@
3333
#include <type_traits>
3434

3535
#include "dpctl4pybind11.hpp"
36-
#include "kernels/constructors.hpp"
37-
#include "kernels/copy_and_cast.hpp"
38-
#include "utils/strided_iters.hpp"
39-
#include "utils/type_dispatch.hpp"
40-
#include "utils/type_utils.hpp"
4136

4237
#include "copy_and_cast_usm_to_usm.hpp"
4338
#include "copy_for_reshape.hpp"
4439
#include "copy_numpy_ndarray_into_usm_ndarray.hpp"
4540
#include "eye_ctor.hpp"
4641
#include "full_ctor.hpp"
4742
#include "linear_sequences.hpp"
48-
#include "simplify_iteration_space.hpp"
43+
#include "triul_ctor.hpp"
44+
#include "utils/strided_iters.hpp"
4945

5046
namespace py = pybind11;
51-
namespace _ns = dpctl::tensor::detail;
5247

5348
namespace
5449
{
5550

5651
using dpctl::tensor::c_contiguous_strides;
5752
using dpctl::tensor::f_contiguous_strides;
5853

59-
using dpctl::utils::keep_args_alive;
60-
6154
using dpctl::tensor::py_internal::copy_usm_ndarray_into_usm_ndarray;
62-
using dpctl::tensor::py_internal::simplify_iteration_space;
6355

6456
/* =========================== Copy for reshape ============================= */
6557

@@ -84,253 +76,28 @@ using dpctl::tensor::py_internal::usm_ndarray_eye;
8476

8577
/* =========================== Tril and triu ============================== */
8678

87-
using dpctl::tensor::kernels::constructors::tri_fn_ptr_t;
88-
89-
static tri_fn_ptr_t tril_generic_dispatch_vector[_ns::num_types];
90-
static tri_fn_ptr_t triu_generic_dispatch_vector[_ns::num_types];
91-
92-
std::pair<sycl::event, sycl::event>
93-
tri(sycl::queue &exec_q,
94-
dpctl::tensor::usm_ndarray src,
95-
dpctl::tensor::usm_ndarray dst,
96-
char part,
97-
py::ssize_t k = 0,
98-
const std::vector<sycl::event> &depends = {})
99-
{
100-
// array dimensions must be the same
101-
int src_nd = src.get_ndim();
102-
int dst_nd = dst.get_ndim();
103-
if (src_nd != dst_nd) {
104-
throw py::value_error("Array dimensions are not the same.");
105-
}
106-
107-
if (src_nd < 2) {
108-
throw py::value_error("Array dimensions less than 2.");
109-
}
110-
111-
// shapes must be the same
112-
const py::ssize_t *src_shape = src.get_shape_raw();
113-
const py::ssize_t *dst_shape = dst.get_shape_raw();
114-
115-
bool shapes_equal(true);
116-
size_t src_nelems(1);
117-
118-
for (int i = 0; shapes_equal && i < src_nd; ++i) {
119-
src_nelems *= static_cast<size_t>(src_shape[i]);
120-
shapes_equal = shapes_equal && (src_shape[i] == dst_shape[i]);
121-
}
122-
if (!shapes_equal) {
123-
throw py::value_error("Array shapes are not the same.");
124-
}
125-
126-
if (src_nelems == 0) {
127-
// nothing to do
128-
return std::make_pair(sycl::event(), sycl::event());
129-
}
130-
131-
char *src_data = src.get_data();
132-
char *dst_data = dst.get_data();
133-
134-
// check that arrays do not overlap, and concurrent copying is safe.
135-
auto src_offsets = src.get_minmax_offsets();
136-
auto dst_offsets = dst.get_minmax_offsets();
137-
int src_elem_size = src.get_elemsize();
138-
int dst_elem_size = dst.get_elemsize();
139-
140-
bool memory_overlap =
141-
((dst_data - src_data > src_offsets.second * src_elem_size -
142-
dst_offsets.first * dst_elem_size) &&
143-
(src_data - dst_data > dst_offsets.second * dst_elem_size -
144-
src_offsets.first * src_elem_size));
145-
if (memory_overlap) {
146-
// TODO: could use a temporary, but this is done by the caller
147-
throw py::value_error("Arrays index overlapping segments of memory");
148-
}
149-
150-
auto array_types = dpctl::tensor::detail::usm_ndarray_types();
151-
152-
int src_typenum = src.get_typenum();
153-
int dst_typenum = dst.get_typenum();
154-
int src_typeid = array_types.typenum_to_lookup_id(src_typenum);
155-
int dst_typeid = array_types.typenum_to_lookup_id(dst_typenum);
156-
157-
if (dst_typeid != src_typeid) {
158-
throw py::value_error("Array dtype are not the same.");
159-
}
160-
161-
// check same contexts
162-
sycl::queue src_q = src.get_queue();
163-
sycl::queue dst_q = dst.get_queue();
164-
165-
if (!dpctl::utils::queues_are_compatible(exec_q, {src_q, dst_q})) {
166-
throw py::value_error(
167-
"Execution queue context is not the same as allocation contexts");
168-
}
169-
170-
using shT = std::vector<py::ssize_t>;
171-
shT src_strides(src_nd);
172-
173-
bool is_src_c_contig = src.is_c_contiguous();
174-
bool is_src_f_contig = src.is_f_contiguous();
175-
176-
const py::ssize_t *src_strides_raw = src.get_strides_raw();
177-
if (src_strides_raw == nullptr) {
178-
if (is_src_c_contig) {
179-
src_strides = c_contiguous_strides(src_nd, src_shape);
180-
}
181-
else if (is_src_f_contig) {
182-
src_strides = f_contiguous_strides(src_nd, src_shape);
183-
}
184-
else {
185-
throw std::runtime_error("Source array has null strides but has "
186-
"neither C- nor F- contiguous flag set");
187-
}
188-
}
189-
else {
190-
std::copy(src_strides_raw, src_strides_raw + src_nd,
191-
src_strides.begin());
192-
}
193-
194-
shT dst_strides(src_nd);
195-
196-
bool is_dst_c_contig = dst.is_c_contiguous();
197-
bool is_dst_f_contig = dst.is_f_contiguous();
198-
199-
const py::ssize_t *dst_strides_raw = dst.get_strides_raw();
200-
if (dst_strides_raw == nullptr) {
201-
if (is_dst_c_contig) {
202-
dst_strides =
203-
dpctl::tensor::c_contiguous_strides(src_nd, src_shape);
204-
}
205-
else if (is_dst_f_contig) {
206-
dst_strides =
207-
dpctl::tensor::f_contiguous_strides(src_nd, src_shape);
208-
}
209-
else {
210-
throw std::runtime_error("Source array has null strides but has "
211-
"neither C- nor F- contiguous flag set");
212-
}
213-
}
214-
else {
215-
std::copy(dst_strides_raw, dst_strides_raw + dst_nd,
216-
dst_strides.begin());
217-
}
218-
219-
shT simplified_shape;
220-
shT simplified_src_strides;
221-
shT simplified_dst_strides;
222-
py::ssize_t src_offset(0);
223-
py::ssize_t dst_offset(0);
224-
225-
constexpr py::ssize_t src_itemsize = 1; // item size in elements
226-
constexpr py::ssize_t dst_itemsize = 1; // item size in elements
227-
228-
int nd = src_nd - 2;
229-
const py::ssize_t *shape = src_shape;
230-
const py::ssize_t *p_src_strides = src_strides.data();
231-
const py::ssize_t *p_dst_strides = dst_strides.data();
232-
233-
simplify_iteration_space(nd, shape, p_src_strides, src_itemsize,
234-
is_src_c_contig, is_src_f_contig, p_dst_strides,
235-
dst_itemsize, is_dst_c_contig, is_dst_f_contig,
236-
simplified_shape, simplified_src_strides,
237-
simplified_dst_strides, src_offset, dst_offset);
238-
239-
if (src_offset != 0 || dst_offset != 0) {
240-
throw py::value_error("Reversed slice for dst is not supported");
241-
}
242-
243-
nd += 2;
244-
245-
using usm_host_allocatorT =
246-
sycl::usm_allocator<py::ssize_t, sycl::usm::alloc::host>;
247-
using usmshT = std::vector<py::ssize_t, usm_host_allocatorT>;
248-
249-
usm_host_allocatorT allocator(exec_q);
250-
auto shp_host_shape_and_strides =
251-
std::make_shared<usmshT>(3 * nd, allocator);
252-
253-
std::copy(simplified_shape.begin(), simplified_shape.end(),
254-
shp_host_shape_and_strides->begin());
255-
(*shp_host_shape_and_strides)[nd - 2] = src_shape[src_nd - 2];
256-
(*shp_host_shape_and_strides)[nd - 1] = src_shape[src_nd - 1];
257-
258-
std::copy(simplified_src_strides.begin(), simplified_src_strides.end(),
259-
shp_host_shape_and_strides->begin() + nd);
260-
(*shp_host_shape_and_strides)[2 * nd - 2] = src_strides[src_nd - 2];
261-
(*shp_host_shape_and_strides)[2 * nd - 1] = src_strides[src_nd - 1];
262-
263-
std::copy(simplified_dst_strides.begin(), simplified_dst_strides.end(),
264-
shp_host_shape_and_strides->begin() + 2 * nd);
265-
(*shp_host_shape_and_strides)[3 * nd - 2] = dst_strides[src_nd - 2];
266-
(*shp_host_shape_and_strides)[3 * nd - 1] = dst_strides[src_nd - 1];
267-
268-
py::ssize_t *dev_shape_and_strides =
269-
sycl::malloc_device<ssize_t>(3 * nd, exec_q);
270-
if (dev_shape_and_strides == nullptr) {
271-
throw std::runtime_error("Unabled to allocate device memory");
272-
}
273-
sycl::event copy_shape_and_strides = exec_q.copy<ssize_t>(
274-
shp_host_shape_and_strides->data(), dev_shape_and_strides, 3 * nd);
275-
276-
py::ssize_t inner_range = src_shape[src_nd - 1] * src_shape[src_nd - 2];
277-
py::ssize_t outer_range = src_nelems / inner_range;
278-
279-
sycl::event tri_ev;
280-
if (part == 'l') {
281-
auto fn = tril_generic_dispatch_vector[src_typeid];
282-
tri_ev =
283-
fn(exec_q, inner_range, outer_range, src_data, dst_data, nd,
284-
dev_shape_and_strides, k, depends, {copy_shape_and_strides});
285-
}
286-
else {
287-
auto fn = triu_generic_dispatch_vector[src_typeid];
288-
tri_ev =
289-
fn(exec_q, inner_range, outer_range, src_data, dst_data, nd,
290-
dev_shape_and_strides, k, depends, {copy_shape_and_strides});
291-
}
292-
293-
exec_q.submit([&](sycl::handler &cgh) {
294-
cgh.depends_on({tri_ev});
295-
auto ctx = exec_q.get_context();
296-
cgh.host_task(
297-
[shp_host_shape_and_strides, dev_shape_and_strides, ctx]() {
298-
// capture of shp_host_shape_and_strides ensure the underlying
299-
// vector exists for the entire execution of copying kernel
300-
sycl::free(dev_shape_and_strides, ctx);
301-
});
302-
});
303-
304-
return std::make_pair(keep_args_alive(exec_q, {src, dst}, {tri_ev}),
305-
tri_ev);
306-
}
79+
using dpctl::tensor::py_internal::usm_ndarray_triul;
30780

30881
// populate dispatch tables
30982
void init_dispatch_tables(void)
31083
{
311-
dpctl::tensor::py_internal::init_copy_and_cast_usm_to_usm_dispatch_tables();
312-
dpctl::tensor::py_internal::
313-
init_copy_numpy_ndarray_into_usm_ndarray_dispatch_tables();
84+
using namespace dpctl::tensor::py_internal;
85+
86+
init_copy_and_cast_usm_to_usm_dispatch_tables();
87+
init_copy_numpy_ndarray_into_usm_ndarray_dispatch_tables();
31488
return;
31589
}
31690

31791
// populate dispatch vectors
31892
void init_dispatch_vectors(void)
31993
{
320-
dpctl::tensor::py_internal::init_copy_for_reshape_dispatch_vectors();
321-
dpctl::tensor::py_internal::init_linear_sequences_dispatch_vectors();
322-
dpctl::tensor::py_internal::init_full_ctor_dispatch_vectors();
323-
dpctl::tensor::py_internal::init_eye_ctor_dispatch_vectors();
324-
325-
using namespace dpctl::tensor::detail;
326-
using dpctl::tensor::kernels::constructors::TrilGenericFactory;
327-
using dpctl::tensor::kernels::constructors::TriuGenericFactory;
328-
329-
DispatchVectorBuilder<tri_fn_ptr_t, TrilGenericFactory, num_types> dvb5;
330-
dvb5.populate_dispatch_vector(tril_generic_dispatch_vector);
94+
using namespace dpctl::tensor::py_internal;
33195

332-
DispatchVectorBuilder<tri_fn_ptr_t, TriuGenericFactory, num_types> dvb6;
333-
dvb6.populate_dispatch_vector(triu_generic_dispatch_vector);
96+
init_copy_for_reshape_dispatch_vectors();
97+
init_linear_sequences_dispatch_vectors();
98+
init_full_ctor_dispatch_vectors();
99+
init_eye_ctor_dispatch_vectors();
100+
init_triul_ctor_dispatch_vectors();
334101

335102
return;
336103
}
@@ -478,7 +245,7 @@ PYBIND11_MODULE(_tensor_impl, m)
478245
py::ssize_t k, sycl::queue exec_q,
479246
const std::vector<sycl::event> depends)
480247
-> std::pair<sycl::event, sycl::event> {
481-
return tri(exec_q, src, dst, 'l', k, depends);
248+
return usm_ndarray_triul(exec_q, src, dst, 'l', k, depends);
482249
},
483250
"Tril helper function.", py::arg("src"), py::arg("dst"),
484251
py::arg("k") = 0, py::arg("sycl_queue"),
@@ -490,7 +257,7 @@ PYBIND11_MODULE(_tensor_impl, m)
490257
py::ssize_t k, sycl::queue exec_q,
491258
const std::vector<sycl::event> depends)
492259
-> std::pair<sycl::event, sycl::event> {
493-
return tri(exec_q, src, dst, 'u', k, depends);
260+
return usm_ndarray_triul(exec_q, src, dst, 'u', k, depends);
494261
},
495262
"Triu helper function.", py::arg("src"), py::arg("dst"),
496263
py::arg("k") = 0, py::arg("sycl_queue"),

0 commit comments

Comments
 (0)