
Commit 9e88171

Factored mask_positions implementation and kernels into separate files
Doing this will make implementing more accumulators convenient
1 parent 9f98baf commit 9e88171

File tree: 8 files changed, +616 −464 lines

dpctl/tensor/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -33,6 +33,7 @@ endif()
 set(python_module_name _tensor_impl)
 pybind11_add_module(${python_module_name} MODULE
     ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/tensor_py.cpp
+    ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/accumulators.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/simplify_iteration_space.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/copy_and_cast_usm_to_usm.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/copy_numpy_ndarray_into_usm_ndarray.cpp
accumulators.hpp (new file)

Lines changed: 349 additions & 0 deletions

@@ -0,0 +1,349 @@
//=== accumulators.hpp - Implementation of accumulator kernels --*-C++-*-/===//
//
// Data Parallel Control (dpctl)
//
// Copyright 2020-2023 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//    http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
//===---------------------------------------------------------------------===//
///
/// \file
/// This file defines kernels for accumulators (cumulative sum, prod, etc.).
//===---------------------------------------------------------------------===//

#pragma once
#include <CL/sycl.hpp>
#include <array>
#include <cstdint>
#include <limits>
#include <pybind11/pybind11.h>
#include <utility>
#include <vector>

#include "utils/offset_utils.hpp"
#include "utils/type_dispatch.hpp"

namespace dpctl
{
namespace tensor
{
namespace kernels
{
namespace accumulators
{

namespace py = pybind11;

using namespace dpctl::tensor::offset_utils;

template <typename T> T ceiling_quotient(T n, T m)
{
    return (n + m - 1) / m;
}
template <typename T1, typename T2> T1 ceiling_quotient(T1 n, T2 m)
{
    return ceiling_quotient<T1>(n, static_cast<T1>(m));
}

template <typename inputT,
          typename outputT,
          size_t n_wi,
          typename IndexerT,
          typename TransformerT>
class inclusive_scan_rec_local_scan_krn;

template <typename inputT,
          typename outputT,
          typename IndexerT,
          typename TransformerT>
class inclusive_scan_rec_chunk_update_krn;

template <typename inputT, typename outputT> struct NonZeroIndicator
{
    NonZeroIndicator() {}

    outputT operator()(const inputT &val) const
    {
        constexpr outputT out_one(1);
        constexpr outputT out_zero(0);
        constexpr inputT val_zero(0);

        return (val == val_zero) ? out_zero : out_one;
    }
};

template <typename T> struct NoOpTransformer
{
    NoOpTransformer() {}

    T operator()(const T &val) const
    {
        return val;
    }
};

/*
 * for integer type maskT,
 *     output[j] = sum( input[s0 + i * s1], 0 <= i <= j)
 * for 0 <= j < n_elems
 */
template <typename inputT,
          typename outputT,
          size_t n_wi,
          typename IndexerT,
          typename TransformerT>
sycl::event inclusive_scan_rec(sycl::queue exec_q,
                               size_t n_elems,
                               size_t wg_size,
                               const inputT *input,
                               outputT *output,
                               size_t s0,
                               size_t s1,
                               IndexerT indexer,
                               TransformerT transformer,
                               std::vector<sycl::event> const &depends = {})
{
    size_t n_groups = ceiling_quotient(n_elems, n_wi * wg_size);

    sycl::event inc_scan_phase1_ev = exec_q.submit([&](sycl::handler &cgh) {
        cgh.depends_on(depends);

        using slmT = sycl::local_accessor<size_t, 1>;

        auto lws = sycl::range<1>(wg_size);
        auto gws = sycl::range<1>(n_groups * wg_size);

        slmT slm_iscan_tmp(lws, cgh);

        cgh.parallel_for<class inclusive_scan_rec_local_scan_krn<
            inputT, outputT, n_wi, IndexerT, decltype(transformer)>>(
            sycl::nd_range<1>(gws, lws), [=](sycl::nd_item<1> it) {
                auto chunk_gid = it.get_global_id(0);
                auto lid = it.get_local_id(0);

                std::array<size_t, n_wi> local_isum;

                size_t i = chunk_gid * n_wi;
                for (size_t m_wi = 0; m_wi < n_wi; ++m_wi) {
                    constexpr outputT out_zero(0);

                    local_isum[m_wi] =
                        (i + m_wi < n_elems)
                            ? transformer(input[indexer(s0 + s1 * (i + m_wi))])
                            : out_zero;
                }

// local_isum is now result of
// inclusive scan of locally stored mask indicators
#pragma unroll
                for (size_t m_wi = 1; m_wi < n_wi; ++m_wi) {
                    local_isum[m_wi] += local_isum[m_wi - 1];
                }

                size_t wg_iscan_val =
                    sycl::inclusive_scan_over_group(it.get_group(),
                                                    local_isum.back(),
                                                    sycl::plus<size_t>(),
                                                    size_t(0));

                slm_iscan_tmp[(lid + 1) % wg_size] = wg_iscan_val;
                it.barrier(sycl::access::fence_space::local_space);
                size_t addand = (lid == 0) ? 0 : slm_iscan_tmp[lid];
                it.barrier(sycl::access::fence_space::local_space);

#pragma unroll
                for (size_t m_wi = 0; m_wi < n_wi; ++m_wi) {
                    local_isum[m_wi] += addand;
                }

                for (size_t m_wi = 0; m_wi < n_wi && i + m_wi < n_elems;
                     ++m_wi) {
                    output[i + m_wi] = local_isum[m_wi];
                }
            });
    });

    sycl::event out_event = inc_scan_phase1_ev;
    if (n_groups > 1) {
        outputT *temp = sycl::malloc_device<outputT>(n_groups - 1, exec_q);

        auto chunk_size = wg_size * n_wi;

        NoOpIndexer _no_op_indexer{};
        NoOpTransformer<outputT> _no_op_transformer{};
        auto e2 = inclusive_scan_rec<outputT, outputT, n_wi, NoOpIndexer,
                                     decltype(_no_op_transformer)>(
            exec_q, n_groups - 1, wg_size, output, temp, chunk_size - 1,
            chunk_size, _no_op_indexer, _no_op_transformer,
            {inc_scan_phase1_ev});

        // output[ chunk_size * (i + 1) + j] += temp[i]
        auto e3 = exec_q.submit([&](sycl::handler &cgh) {
            cgh.depends_on(e2);
            cgh.parallel_for<class inclusive_scan_rec_chunk_update_krn<
                inputT, outputT, IndexerT, decltype(transformer)>>(
                {n_elems}, [=](auto wiid) {
                    auto gid = wiid[0];
                    auto i = (gid / chunk_size);
                    output[gid] += (i > 0) ? temp[i - 1] : 0;
                });
        });

        sycl::event e4 = exec_q.submit([&](sycl::handler &cgh) {
            cgh.depends_on(e3);
            auto ctx = exec_q.get_context();
            cgh.host_task([ctx, temp]() { sycl::free(temp, ctx); });
        });

        out_event = e4;
    }

    return out_event;
}

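Aside: a minimal serial sketch of the same two-phase blocked scan (plain C++, no SYCL; chunked_inclusive_scan is an illustrative name, not part of this commit). Phase 1 scans each chunk independently; phase 2 adds the running total of all preceding chunks, which the kernel above obtains by recursing inclusive_scan_rec over the chunk-final values.

#include <algorithm>
#include <cstddef>
#include <vector>

// Serial analogue of the two-phase scheme: scan each chunk of
// `chunk_size` elements locally, then propagate per-chunk totals
// into every later chunk.
std::vector<std::size_t>
chunked_inclusive_scan(const std::vector<std::size_t> &in,
                       std::size_t chunk_size)
{
    std::vector<std::size_t> out(in.size());
    for (std::size_t base = 0; base < in.size(); base += chunk_size) {
        std::size_t hi = std::min(base + chunk_size, in.size());
        std::size_t run = 0;
        for (std::size_t j = base; j < hi; ++j) { // phase 1: local scan
            run += in[j];
            out[j] = run;
        }
    }
    for (std::size_t base = chunk_size; base < in.size(); base += chunk_size) {
        std::size_t hi = std::min(base + chunk_size, in.size());
        std::size_t carry = out[base - 1]; // already includes earlier carries
        for (std::size_t j = base; j < hi; ++j) { // phase 2: propagate
            out[j] += carry;
        }
    }
    return out;
}
// e.g. in = {1, 0, 1, 1, 0, 1}, chunk_size = 2  ->  out = {1, 1, 2, 3, 3, 4}
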
// mask positions

typedef size_t (*mask_positions_contig_impl_fn_ptr_t)(
    sycl::queue,
    size_t,
    const char *,
    char *,
    std::vector<sycl::event> const &);

template <typename maskT, typename cumsumT>
size_t mask_positions_contig_impl(sycl::queue q,
                                  size_t n_elems,
                                  const char *mask,
                                  char *cumsum,
                                  std::vector<sycl::event> const &depends = {})
{
    constexpr int n_wi = 8;
    const maskT *mask_data_ptr = reinterpret_cast<const maskT *>(mask);
    cumsumT *cumsum_data_ptr = reinterpret_cast<cumsumT *>(cumsum);
    size_t wg_size = 128;

    NoOpIndexer flat_indexer{};
    NonZeroIndicator<maskT, cumsumT> non_zero_indicator{};

    sycl::event comp_ev =
        inclusive_scan_rec<maskT, cumsumT, n_wi, decltype(flat_indexer),
                           decltype(non_zero_indicator)>(
            q, n_elems, wg_size, mask_data_ptr, cumsum_data_ptr, 0, 1,
            flat_indexer, non_zero_indicator, depends);

    cumsumT *last_elem = cumsum_data_ptr + (n_elems - 1);

    cumsumT *last_elem_host_usm = sycl::malloc_host<cumsumT>(1, q);

    if (last_elem_host_usm == nullptr) {
        throw std::bad_alloc();
    }
    sycl::event copy_e =
        q.copy<cumsumT>(last_elem, last_elem_host_usm, 1, {comp_ev});
    copy_e.wait();
    size_t return_val = static_cast<size_t>(*last_elem_host_usm);
    sycl::free(last_elem_host_usm, q);

    return return_val;
}

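A possible call site for the contiguous entry point (a sketch only; the bool/int64 type pair, the size, and the buffer handling are assumptions, not part of this commit):

#include <cstdint>

// Hypothetical call site: count nonzero entries of a contiguous bool mask
// while materializing its cumulative sum in device USM.
size_t count_mask_positions_example(sycl::queue &q)
{
    constexpr size_t n = 1024;
    bool *mask = sycl::malloc_device<bool>(n, q);
    std::int64_t *cumsum = sycl::malloc_device<std::int64_t>(n, q);
    // ... fill mask on the device ...
    size_t n_true = mask_positions_contig_impl<bool, std::int64_t>(
        q, n, reinterpret_cast<const char *>(mask),
        reinterpret_cast<char *>(cumsum));
    // cumsum[j] now holds the number of true values in mask[0..j], and
    // n_true == cumsum[n - 1]; the call blocks until both are ready.
    sycl::free(mask, q);
    sycl::free(cumsum, q);
    return n_true;
}
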
template <typename fnT, typename T> struct MaskPositionsContigFactoryForInt32
{
    fnT get()
    {
        fnT fn = mask_positions_contig_impl<T, std::int32_t>;
        return fn;
    }
};

template <typename fnT, typename T> struct MaskPositionsContigFactoryForInt64
{
    fnT get()
    {
        fnT fn = mask_positions_contig_impl<T, std::int64_t>;
        return fn;
    }
};

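These factories are presumably consumed by the type-dispatch utilities included from utils/type_dispatch.hpp. The sketch below shows how a per-dtype table might be populated, assuming the DispatchVectorBuilder and num_types helpers dpctl uses for similar kernels (the vector and function names here are illustrative):

// Sketch (assumed dispatch API): one contiguous-kernel entry per source
// dtype, all accumulating into std::int64_t.
namespace td_ns = dpctl::tensor::type_dispatch;

static mask_positions_contig_impl_fn_ptr_t
    mask_positions_contig_i64_dispatch_vector[td_ns::num_types];

void populate_mask_positions_dispatch_vectors(void)
{
    td_ns::DispatchVectorBuilder<mask_positions_contig_impl_fn_ptr_t,
                                 MaskPositionsContigFactoryForInt64,
                                 td_ns::num_types>
        dvb;
    dvb.populate_dispatch_vector(mask_positions_contig_i64_dispatch_vector);
}
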
typedef size_t (*mask_positions_strided_impl_fn_ptr_t)(
    sycl::queue,
    size_t,
    const char *,
    int,
    const py::ssize_t *,
    char *,
    std::vector<sycl::event> const &);

template <typename maskT, typename cumsumT>
size_t mask_positions_strided_impl(sycl::queue q,
                                   size_t n_elems,
                                   const char *mask,
                                   int nd,
                                   const py::ssize_t *shape_strides,
                                   char *cumsum,
                                   std::vector<sycl::event> const &depends = {})
{
    constexpr int n_wi = 8;
    const maskT *mask_data_ptr = reinterpret_cast<const maskT *>(mask);
    cumsumT *cumsum_data_ptr = reinterpret_cast<cumsumT *>(cumsum);
    size_t wg_size = 128;

    StridedIndexer strided_indexer{nd, 0, shape_strides};
    NonZeroIndicator<maskT, cumsumT> non_zero_indicator{};

    sycl::event comp_ev =
        inclusive_scan_rec<maskT, cumsumT, n_wi, decltype(strided_indexer),
                           decltype(non_zero_indicator)>(
            q, n_elems, wg_size, mask_data_ptr, cumsum_data_ptr, 0, 1,
            strided_indexer, non_zero_indicator, depends);

    cumsumT *last_elem = cumsum_data_ptr + (n_elems - 1);

    cumsumT *last_elem_host_usm = sycl::malloc_host<cumsumT>(1, q);

    if (last_elem_host_usm == nullptr) {
        throw std::bad_alloc();
    }
    sycl::event copy_e =
        q.copy<cumsumT>(last_elem, last_elem_host_usm, 1, {comp_ev});
    copy_e.wait();
    size_t return_val = static_cast<size_t>(*last_elem_host_usm);
    sycl::free(last_elem_host_usm, q);

    return return_val;
}

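The strided variant differs from the contiguous one only in its indexer. For orientation, a serial sketch of the flat-id-to-offset mapping such an indexer performs, assuming shape_strides packs the nd shape entries followed by the nd element strides (the packing convention used by dpctl's offset utilities; strided_offset is an illustrative helper, not part of this commit):

// Illustrative: unravel a flat element id in C (row-major) order and
// combine it with per-dimension element strides, given
// shape_strides = [shape[0], ..., shape[nd-1], stride[0], ..., stride[nd-1]].
py::ssize_t strided_offset(py::ssize_t flat_id, int nd,
                           const py::ssize_t *shape_strides)
{
    const py::ssize_t *shape = shape_strides;
    const py::ssize_t *strides = shape_strides + nd;
    py::ssize_t offset = 0;
    for (int d = nd - 1; d >= 0; --d) {
        offset += (flat_id % shape[d]) * strides[d];
        flat_id /= shape[d];
    }
    return offset;
}
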
template <typename fnT, typename T> struct MaskPositionsStridedFactoryForInt32
{
    fnT get()
    {
        fnT fn = mask_positions_strided_impl<T, std::int32_t>;
        return fn;
    }
};

template <typename fnT, typename T> struct MaskPositionsStridedFactoryForInt64
{
    fnT get()
    {
        fnT fn = mask_positions_strided_impl<T, std::int64_t>;
        return fn;
    }
};

} // namespace accumulators
} // namespace kernels
} // namespace tensor
} // namespace dpctl
