 #include <type_traits>
 
 #include "dpctl4pybind11.hpp"
-#include "kernels/constructors.hpp"
-#include "kernels/copy_and_cast.hpp"
-#include "utils/strided_iters.hpp"
-#include "utils/type_dispatch.hpp"
-#include "utils/type_utils.hpp"
 
 #include "copy_and_cast_usm_to_usm.hpp"
 #include "copy_for_reshape.hpp"
 #include "copy_numpy_ndarray_into_usm_ndarray.hpp"
 #include "eye_ctor.hpp"
 #include "full_ctor.hpp"
 #include "linear_sequences.hpp"
-#include "simplify_iteration_space.hpp"
+#include "triul_ctor.hpp"
+#include "utils/strided_iters.hpp"
 
 namespace py = pybind11;
-namespace _ns = dpctl::tensor::detail;
 
 namespace
 {
 
 using dpctl::tensor::c_contiguous_strides;
 using dpctl::tensor::f_contiguous_strides;
 
-using dpctl::utils::keep_args_alive;
-
 using dpctl::tensor::py_internal::copy_usm_ndarray_into_usm_ndarray;
-using dpctl::tensor::py_internal::simplify_iteration_space;
 
 /* =========================== Copy for reshape ============================= */
 
@@ -84,253 +76,28 @@ using dpctl::tensor::py_internal::usm_ndarray_eye;
 
 /* =========================== Tril and triu ============================== */
 
-using dpctl::tensor::kernels::constructors::tri_fn_ptr_t;
-
-static tri_fn_ptr_t tril_generic_dispatch_vector[_ns::num_types];
-static tri_fn_ptr_t triu_generic_dispatch_vector[_ns::num_types];
-
-std::pair<sycl::event, sycl::event>
-tri(sycl::queue &exec_q,
-    dpctl::tensor::usm_ndarray src,
-    dpctl::tensor::usm_ndarray dst,
-    char part,
-    py::ssize_t k = 0,
-    const std::vector<sycl::event> &depends = {})
-{
-    // array dimensions must be the same
-    int src_nd = src.get_ndim();
-    int dst_nd = dst.get_ndim();
-    if (src_nd != dst_nd) {
-        throw py::value_error("Array dimensions are not the same.");
-    }
-
-    if (src_nd < 2) {
-        throw py::value_error("Array dimensions less than 2.");
-    }
-
-    // shapes must be the same
-    const py::ssize_t *src_shape = src.get_shape_raw();
-    const py::ssize_t *dst_shape = dst.get_shape_raw();
-
-    bool shapes_equal(true);
-    size_t src_nelems(1);
-
-    for (int i = 0; shapes_equal && i < src_nd; ++i) {
-        src_nelems *= static_cast<size_t>(src_shape[i]);
-        shapes_equal = shapes_equal && (src_shape[i] == dst_shape[i]);
-    }
-    if (!shapes_equal) {
-        throw py::value_error("Array shapes are not the same.");
-    }
-
-    if (src_nelems == 0) {
-        // nothing to do
-        return std::make_pair(sycl::event(), sycl::event());
-    }
-
-    char *src_data = src.get_data();
-    char *dst_data = dst.get_data();
-
-    // check that arrays do not overlap, and concurrent copying is safe.
-    auto src_offsets = src.get_minmax_offsets();
-    auto dst_offsets = dst.get_minmax_offsets();
-    int src_elem_size = src.get_elemsize();
-    int dst_elem_size = dst.get_elemsize();
-
-    bool memory_overlap =
-        ((dst_data - src_data > src_offsets.second * src_elem_size -
-                                    dst_offsets.first * dst_elem_size) &&
-         (src_data - dst_data > dst_offsets.second * dst_elem_size -
-                                    src_offsets.first * src_elem_size));
-    if (memory_overlap) {
-        // TODO: could use a temporary, but this is done by the caller
-        throw py::value_error("Arrays index overlapping segments of memory");
-    }
-
-    auto array_types = dpctl::tensor::detail::usm_ndarray_types();
-
-    int src_typenum = src.get_typenum();
-    int dst_typenum = dst.get_typenum();
-    int src_typeid = array_types.typenum_to_lookup_id(src_typenum);
-    int dst_typeid = array_types.typenum_to_lookup_id(dst_typenum);
-
-    if (dst_typeid != src_typeid) {
-        throw py::value_error("Array dtype are not the same.");
-    }
-
-    // check same contexts
-    sycl::queue src_q = src.get_queue();
-    sycl::queue dst_q = dst.get_queue();
-
-    if (!dpctl::utils::queues_are_compatible(exec_q, {src_q, dst_q})) {
-        throw py::value_error(
-            "Execution queue context is not the same as allocation contexts");
-    }
-
-    using shT = std::vector<py::ssize_t>;
-    shT src_strides(src_nd);
-
-    bool is_src_c_contig = src.is_c_contiguous();
-    bool is_src_f_contig = src.is_f_contiguous();
-
-    const py::ssize_t *src_strides_raw = src.get_strides_raw();
-    if (src_strides_raw == nullptr) {
-        if (is_src_c_contig) {
-            src_strides = c_contiguous_strides(src_nd, src_shape);
-        }
-        else if (is_src_f_contig) {
-            src_strides = f_contiguous_strides(src_nd, src_shape);
-        }
-        else {
-            throw std::runtime_error("Source array has null strides but has "
-                                     "neither C- nor F- contiguous flag set");
-        }
-    }
-    else {
-        std::copy(src_strides_raw, src_strides_raw + src_nd,
-                  src_strides.begin());
-    }
-
-    shT dst_strides(src_nd);
-
-    bool is_dst_c_contig = dst.is_c_contiguous();
-    bool is_dst_f_contig = dst.is_f_contiguous();
-
-    const py::ssize_t *dst_strides_raw = dst.get_strides_raw();
-    if (dst_strides_raw == nullptr) {
-        if (is_dst_c_contig) {
-            dst_strides =
-                dpctl::tensor::c_contiguous_strides(src_nd, src_shape);
-        }
-        else if (is_dst_f_contig) {
-            dst_strides =
-                dpctl::tensor::f_contiguous_strides(src_nd, src_shape);
-        }
-        else {
-            throw std::runtime_error("Source array has null strides but has "
-                                     "neither C- nor F- contiguous flag set");
-        }
-    }
-    else {
-        std::copy(dst_strides_raw, dst_strides_raw + dst_nd,
-                  dst_strides.begin());
-    }
-
-    shT simplified_shape;
-    shT simplified_src_strides;
-    shT simplified_dst_strides;
-    py::ssize_t src_offset(0);
-    py::ssize_t dst_offset(0);
-
-    constexpr py::ssize_t src_itemsize = 1; // item size in elements
-    constexpr py::ssize_t dst_itemsize = 1; // item size in elements
-
-    int nd = src_nd - 2;
-    const py::ssize_t *shape = src_shape;
-    const py::ssize_t *p_src_strides = src_strides.data();
-    const py::ssize_t *p_dst_strides = dst_strides.data();
-
-    simplify_iteration_space(nd, shape, p_src_strides, src_itemsize,
-                             is_src_c_contig, is_src_f_contig, p_dst_strides,
-                             dst_itemsize, is_dst_c_contig, is_dst_f_contig,
-                             simplified_shape, simplified_src_strides,
-                             simplified_dst_strides, src_offset, dst_offset);
-
-    if (src_offset != 0 || dst_offset != 0) {
-        throw py::value_error("Reversed slice for dst is not supported");
-    }
-
-    nd += 2;
-
-    using usm_host_allocatorT =
-        sycl::usm_allocator<py::ssize_t, sycl::usm::alloc::host>;
-    using usmshT = std::vector<py::ssize_t, usm_host_allocatorT>;
-
-    usm_host_allocatorT allocator(exec_q);
-    auto shp_host_shape_and_strides =
-        std::make_shared<usmshT>(3 * nd, allocator);
-
-    std::copy(simplified_shape.begin(), simplified_shape.end(),
-              shp_host_shape_and_strides->begin());
-    (*shp_host_shape_and_strides)[nd - 2] = src_shape[src_nd - 2];
-    (*shp_host_shape_and_strides)[nd - 1] = src_shape[src_nd - 1];
-
-    std::copy(simplified_src_strides.begin(), simplified_src_strides.end(),
-              shp_host_shape_and_strides->begin() + nd);
-    (*shp_host_shape_and_strides)[2 * nd - 2] = src_strides[src_nd - 2];
-    (*shp_host_shape_and_strides)[2 * nd - 1] = src_strides[src_nd - 1];
-
-    std::copy(simplified_dst_strides.begin(), simplified_dst_strides.end(),
-              shp_host_shape_and_strides->begin() + 2 * nd);
-    (*shp_host_shape_and_strides)[3 * nd - 2] = dst_strides[src_nd - 2];
-    (*shp_host_shape_and_strides)[3 * nd - 1] = dst_strides[src_nd - 1];
-
-    py::ssize_t *dev_shape_and_strides =
-        sycl::malloc_device<ssize_t>(3 * nd, exec_q);
-    if (dev_shape_and_strides == nullptr) {
-        throw std::runtime_error("Unabled to allocate device memory");
-    }
-    sycl::event copy_shape_and_strides = exec_q.copy<ssize_t>(
-        shp_host_shape_and_strides->data(), dev_shape_and_strides, 3 * nd);
-
-    py::ssize_t inner_range = src_shape[src_nd - 1] * src_shape[src_nd - 2];
-    py::ssize_t outer_range = src_nelems / inner_range;
-
-    sycl::event tri_ev;
-    if (part == 'l') {
-        auto fn = tril_generic_dispatch_vector[src_typeid];
-        tri_ev =
-            fn(exec_q, inner_range, outer_range, src_data, dst_data, nd,
-               dev_shape_and_strides, k, depends, {copy_shape_and_strides});
-    }
-    else {
-        auto fn = triu_generic_dispatch_vector[src_typeid];
-        tri_ev =
-            fn(exec_q, inner_range, outer_range, src_data, dst_data, nd,
-               dev_shape_and_strides, k, depends, {copy_shape_and_strides});
-    }
-
-    exec_q.submit([&](sycl::handler &cgh) {
-        cgh.depends_on({tri_ev});
-        auto ctx = exec_q.get_context();
-        cgh.host_task(
-            [shp_host_shape_and_strides, dev_shape_and_strides, ctx]() {
-                // capture of shp_host_shape_and_strides ensure the underlying
-                // vector exists for the entire execution of copying kernel
-                sycl::free(dev_shape_and_strides, ctx);
-            });
-    });
-
-    return std::make_pair(keep_args_alive(exec_q, {src, dst}, {tri_ev}),
-                          tri_ev);
-}
+using dpctl::tensor::py_internal::usm_ndarray_triul;
 
 // populate dispatch tables
 void init_dispatch_tables(void)
 {
-    dpctl::tensor::py_internal::init_copy_and_cast_usm_to_usm_dispatch_tables();
-    dpctl::tensor::py_internal::
-        init_copy_numpy_ndarray_into_usm_ndarray_dispatch_tables();
+    using namespace dpctl::tensor::py_internal;
+
+    init_copy_and_cast_usm_to_usm_dispatch_tables();
+    init_copy_numpy_ndarray_into_usm_ndarray_dispatch_tables();
     return;
 }
 
 // populate dispatch vectors
 void init_dispatch_vectors(void)
 {
-    dpctl::tensor::py_internal::init_copy_for_reshape_dispatch_vectors();
-    dpctl::tensor::py_internal::init_linear_sequences_dispatch_vectors();
-    dpctl::tensor::py_internal::init_full_ctor_dispatch_vectors();
-    dpctl::tensor::py_internal::init_eye_ctor_dispatch_vectors();
-
-    using namespace dpctl::tensor::detail;
-    using dpctl::tensor::kernels::constructors::TrilGenericFactory;
-    using dpctl::tensor::kernels::constructors::TriuGenericFactory;
-
-    DispatchVectorBuilder<tri_fn_ptr_t, TrilGenericFactory, num_types> dvb5;
-    dvb5.populate_dispatch_vector(tril_generic_dispatch_vector);
+    using namespace dpctl::tensor::py_internal;
 
-    DispatchVectorBuilder<tri_fn_ptr_t, TriuGenericFactory, num_types> dvb6;
-    dvb6.populate_dispatch_vector(triu_generic_dispatch_vector);
+    init_copy_for_reshape_dispatch_vectors();
+    init_linear_sequences_dispatch_vectors();
+    init_full_ctor_dispatch_vectors();
+    init_eye_ctor_dispatch_vectors();
+    init_triul_ctor_dispatch_vectors();
 
     return;
 }
@@ -478,7 +245,7 @@ PYBIND11_MODULE(_tensor_impl, m)
            py::ssize_t k, sycl::queue exec_q,
            const std::vector<sycl::event> depends)
             -> std::pair<sycl::event, sycl::event> {
-            return tri(exec_q, src, dst, 'l', k, depends);
+            return usm_ndarray_triul(exec_q, src, dst, 'l', k, depends);
         },
         "Tril helper function.", py::arg("src"), py::arg("dst"),
         py::arg("k") = 0, py::arg("sycl_queue"),
@@ -490,7 +257,7 @@ PYBIND11_MODULE(_tensor_impl, m)
            py::ssize_t k, sycl::queue exec_q,
            const std::vector<sycl::event> depends)
             -> std::pair<sycl::event, sycl::event> {
-            return tri(exec_q, src, dst, 'u', k, depends);
+            return usm_ndarray_triul(exec_q, src, dst, 'u', k, depends);
         },
         "Triu helper function.", py::arg("src"), py::arg("dst"),
         py::arg("k") = 0, py::arg("sycl_queue"),