@@ -46,51 +46,6 @@ namespace tensor
46
46
namespace py_internal
47
47
{
48
48
49
- /* @brief Split shape/strides into dir1 (complementary to axis_start <= i <
50
- * axis_end) and dir2 (along given set of axes)
51
- */
52
- template <typename shT>
53
- void _split_iteration_space (const shT &shape_vec,
54
- const shT &strides_vec,
55
- int axis_start,
56
- int axis_end,
57
- shT &dir1_shape_vec,
58
- shT &dir2_shape_vec,
59
- shT &dir1_strides_vec,
60
- shT &dir2_strides_vec)
61
- {
62
- int nd = static_cast <int >(shape_vec.size ());
63
- int dir2_sz = axis_end - axis_start;
64
- int dir1_sz = nd - dir2_sz;
65
-
66
- assert (dir1_sz > 0 );
67
- assert (dir2_sz > 0 );
68
-
69
- dir1_shape_vec.resize (dir1_sz);
70
- dir2_shape_vec.resize (dir2_sz);
71
-
72
- std::copy (shape_vec.begin (), shape_vec.begin () + axis_start,
73
- dir1_shape_vec.begin ());
74
- std::copy (shape_vec.begin () + axis_end, shape_vec.end (),
75
- dir1_shape_vec.begin () + axis_start);
76
-
77
- std::copy (shape_vec.begin () + axis_start, shape_vec.begin () + axis_end,
78
- dir2_shape_vec.begin ());
79
-
80
- dir1_strides_vec.resize (dir1_sz);
81
- dir2_strides_vec.resize (dir2_sz);
82
-
83
- std::copy (strides_vec.begin (), strides_vec.begin () + axis_start,
84
- dir1_strides_vec.begin ());
85
- std::copy (strides_vec.begin () + axis_end, strides_vec.end (),
86
- dir1_strides_vec.begin () + axis_start);
87
-
88
- std::copy (strides_vec.begin () + axis_start, strides_vec.begin () + axis_end,
89
- dir2_strides_vec.begin ());
90
-
91
- return ;
92
- }
93
-
94
49
// Masked extraction
95
50
96
51
namespace td_ns = dpctl::tensor::type_dispatch;
@@ -334,19 +289,21 @@ py_extract(dpctl::tensor::usm_ndarray src,
334
289
shT masked_src_shape;
335
290
shT ortho_src_strides;
336
291
shT masked_src_strides;
337
- _split_iteration_space (src_shape_vec, src_strides_vec, axis_start,
338
- axis_end, ortho_src_shape,
339
- masked_src_shape, // 4 vectors modified
340
- ortho_src_strides, masked_src_strides);
292
+ dpctl::tensor::py_internal::split_iteration_space (
293
+ src_shape_vec, src_strides_vec, axis_start, axis_end,
294
+ ortho_src_shape,
295
+ masked_src_shape, // 4 vectors modified
296
+ ortho_src_strides, masked_src_strides);
341
297
342
298
shT ortho_dst_shape;
343
299
shT masked_dst_shape;
344
300
shT ortho_dst_strides;
345
301
shT masked_dst_strides;
346
- _split_iteration_space (dst_shape_vec, dst_strides_vec, axis_start,
347
- axis_start + 1 , ortho_dst_shape,
348
- masked_dst_shape, // 4 vectors modified
349
- ortho_dst_strides, masked_dst_strides);
302
+ dpctl::tensor::py_internal::split_iteration_space (
303
+ dst_shape_vec, dst_strides_vec, axis_start, axis_start + 1 ,
304
+ ortho_dst_shape,
305
+ masked_dst_shape, // 4 vectors modified
306
+ ortho_dst_strides, masked_dst_strides);
350
307
351
308
assert (ortho_src_shape.size () == static_cast <size_t >(ortho_nd));
352
309
assert (ortho_dst_shape.size () == static_cast <size_t >(ortho_nd));
@@ -662,19 +619,21 @@ py_place(dpctl::tensor::usm_ndarray dst,
662
619
shT masked_dst_shape;
663
620
shT ortho_dst_strides;
664
621
shT masked_dst_strides;
665
- _split_iteration_space (dst_shape_vec, dst_strides_vec, axis_start,
666
- axis_end, ortho_dst_shape,
667
- masked_dst_shape, // 4 vectors modified
668
- ortho_dst_strides, masked_dst_strides);
622
+ dpctl::tensor::py_internal::split_iteration_space (
623
+ dst_shape_vec, dst_strides_vec, axis_start, axis_end,
624
+ ortho_dst_shape,
625
+ masked_dst_shape, // 4 vectors modified
626
+ ortho_dst_strides, masked_dst_strides);
669
627
670
628
shT ortho_rhs_shape;
671
629
shT masked_rhs_shape;
672
630
shT ortho_rhs_strides;
673
631
shT masked_rhs_strides;
674
- _split_iteration_space (rhs_shape_vec, rhs_strides_vec, axis_start,
675
- axis_start + 1 , ortho_rhs_shape,
676
- masked_rhs_shape, // 4 vectors modified
677
- ortho_rhs_strides, masked_rhs_strides);
632
+ dpctl::tensor::py_internal::split_iteration_space (
633
+ rhs_shape_vec, rhs_strides_vec, axis_start, axis_start + 1 ,
634
+ ortho_rhs_shape,
635
+ masked_rhs_shape, // 4 vectors modified
636
+ ortho_rhs_strides, masked_rhs_strides);
678
637
679
638
assert (ortho_dst_shape.size () == static_cast <size_t >(ortho_nd));
680
639
assert (ortho_rhs_shape.size () == static_cast <size_t >(ortho_nd));
0 commit comments