Commit 7ed9c7e

Initial commit
1 parent 316240c commit 7ed9c7e

5 files changed: +796 -0 lines changed

dpnp/backend/extensions/statistics/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -30,6 +30,7 @@ set(_module_src
     ${CMAKE_CURRENT_SOURCE_DIR}/histogram.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/histogramdd.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/histogram_common.cpp
+    ${CMAKE_CURRENT_SOURCE_DIR}/kth_element1d.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/sliding_dot_product1d.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/sliding_window1d.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/statistics_py.cpp
dpnp/backend/extensions/statistics/kth_element1d.cpp

Lines changed: 381 additions & 0 deletions
@@ -0,0 +1,381 @@
//*****************************************************************************
// Copyright (c) 2024-2025, Intel Corporation
// All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
// - Redistributions of source code must retain the above copyright notice,
//   this list of conditions and the following disclaimer.
// - Redistributions in binary form must reproduce the above copyright notice,
//   this list of conditions and the following disclaimer in the documentation
//   and/or other materials provided with the distribution.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
// THE POSSIBILITY OF SUCH DAMAGE.
//*****************************************************************************

#include <cmath>
#include <complex>
#include <memory>
#include <vector>

#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

// dpctl tensor headers
#include "dpctl4pybind11.hpp"
#include "utils/sycl_alloc_utils.hpp"
#include "utils/type_dispatch.hpp"

#include "ext/common.hpp"
#include "kth_element1d.hpp"
#include "partitioning.hpp"

// #include <iostream>

namespace sycl_exp = sycl::ext::oneapi::experimental;
namespace dpctl_td_ns = dpctl::tensor::type_dispatch;
namespace dpctl_utils = dpctl::tensor::alloc_utils;

using dpctl::tensor::usm_ndarray;

using namespace statistics::partitioning;
using namespace ext::common;

namespace
{

template <typename T>
struct pick_pivot_kernel;

template <typename T>
struct KthElementF
{
    static sycl::event run_pick_pivot(sycl::queue &queue,
                                      T *in,
                                      T *out,
                                      uint64_t target,
                                      State<T> &state,
                                      uint64_t items_to_sort,
                                      uint64_t limit,
                                      const std::vector<sycl::event> &deps)
    {
        auto e = queue.submit([&](sycl::handler &cgh) {
            cgh.depends_on(deps);
            constexpr uint64_t group_size = 128;

            auto work_sz = make_ndrange(group_size, group_size, 1);

            size_t temp_memory_size =
                sycl_exp::default_sorters::joint_sorter<>::memory_required<T>(
                    sycl::memory_scope::work_group, limit);

            auto loc_items =
                sycl::local_accessor<T, 1>(sycl::range<1>(items_to_sort), cgh);
            auto scratch = sycl::local_accessor<std::byte, 1>(
                sycl::range<1>(temp_memory_size), cgh);

            cgh.parallel_for<pick_pivot_kernel<T>>(
                work_sz, [=](sycl::nd_item<1> item) {
                    auto group = item.get_group();

                    if (state.stop[0])
                        return;

                    auto llid = item.get_local_linear_id();
                    auto local_size = item.get_group_range(0);

                    uint64_t num_elems = 0;
                    bool target_found = false;

                    T *_in = nullptr;
                    if (group.leader()) {
                        state.update_counters();
                        auto less_count = state.counters.less_count[0];
                        bool left = target < less_count;
                        state.left[0] = left;

                        if (left) {
                            _in = in;
                            num_elems = state.iteration_counters.less_count[0];
                            if (target + 1 == less_count) {
                                _in[num_elems] = state.pivot[0];
                                state.counters.less_count[0] += 1;
                                num_elems += 1;
                            }
                        }
                        else {
                            num_elems =
                                state.iteration_counters.greater_equal_count[0];
                            _in = in + state.n - num_elems;

                            if (target + 1 <
                                less_count +
                                    state.iteration_counters.equal_count[0]) {
                                state.values[0] = state.pivot[0];
                                state.values[1] = state.pivot[0];

                                state.stop[0] = true;
                                state.target_found[0] = true;
                                target_found = true;
                            }
                        }

                        state.reset_iteration_counters();
                    }

                    target_found =
                        sycl::group_broadcast(group, target_found, 0);
                    _in = sycl::group_broadcast(group, _in, 0);
                    num_elems = sycl::group_broadcast(group, num_elems, 0);

                    if (target_found) {
                        return;
                    }

                    if (num_elems <= limit) {
                        auto gh = sycl_exp::group_with_scratchpad(
                            group, sycl::span{&scratch[0], temp_memory_size});
                        sycl_exp::joint_sort(gh, &_in[0], &_in[num_elems]);

                        if (group.leader()) {
                            uint64_t offset = state.counters.less_count[0];
                            if (state.left[0]) {
                                offset =
                                    state.counters.less_count[0] - num_elems;
                            }

                            uint64_t idx = target - offset;
                            state.values[0] = _in[idx];
                            state.values[1] = _in[idx + 1];

                            state.stop[0] = true;
                            state.target_found[0] = true;
                        }

                        return;
                    }

                    uint64_t step = num_elems / items_to_sort;
                    for (uint32_t i = llid; i < items_to_sort; i += local_size)
                    {
                        loc_items[i] = std::numeric_limits<T>::max();
                        uint32_t idx = i * step;
                        if (idx < num_elems) {
                            loc_items[i] = _in[idx];
                        }
                    }

                    sycl::group_barrier(group);

                    auto gh = sycl_exp::group_with_scratchpad(
                        group, sycl::span{&scratch[0], temp_memory_size});
                    sycl_exp::joint_sort(gh, &loc_items[0],
                                         &loc_items[0] + items_to_sort);

                    T new_pivot = loc_items[items_to_sort / 2];

                    if (new_pivot != state.pivot[0]) {
                        if (group.leader()) {
                            state.pivot[0] = new_pivot;
                            state.num_elems[0] = num_elems;
                        }
                        return;
                    }

                    auto start = llid + items_to_sort / 2 + 1;
                    uint32_t index = start;
                    for (uint32_t i = start; i < items_to_sort; i += local_size)
                    {
                        if (loc_items[i] != new_pivot) {
                            index = i;
                            break;
                        }
                    }

                    index = sycl::reduce_over_group(group, index,
                                                    sycl::minimum<>());
                    if (group.leader()) {
                        state.pivot[0] = loc_items[index];
                        state.num_elems[0] = num_elems;
                    }
                });
        });

        return e;
    }

    static sycl::event run_partition(sycl::queue &exec_q,
                                     T *in,
                                     T *out,
                                     PartitionState<T> &state,
                                     const std::vector<sycl::event> &deps)
    {

        uint32_t group_size = 128;
        auto e = exec_q.submit([&](sycl::handler &cgh) {
            cgh.depends_on(deps);

            constexpr uint32_t WorkPI = 4; // empirically found number

            auto work_range = make_ndrange(state.n, group_size, WorkPI);
            submit_partition_one_pivot<T, WorkPI>(cgh, work_range, in, out,
                                                  state);
        });

        return e;
    }

    static sycl::event run_kth_element(sycl::queue &exec_q,
                                       const T *in,
                                       T *partitioned,
                                       const size_t k,
                                       State<T> &state,
                                       PartitionState<T> &pstate,
                                       const std::vector<sycl::event> &depends)
    {
        uint32_t items_to_sort = 128;
        uint32_t limit = 4 * items_to_sort;
        uint32_t iterations =
            std::ceil(std::log(double(state.n) / limit) / std::log(2));

        auto temp_buff = dpctl_utils::smart_malloc<T>(state.n, exec_q,
                                                      sycl::usm::alloc::device);

        auto prev = run_pick_pivot(exec_q, const_cast<T *>(in), partitioned, k,
                                   state, items_to_sort, limit, depends);
        prev = run_partition(exec_q, const_cast<T *>(in), partitioned, pstate,
                             {prev});

        T *_in = partitioned;
        T *_out = temp_buff.get();
        for (uint32_t i = 0; i < iterations - 1; ++i) {
            prev = run_pick_pivot(exec_q, _in, _out, k, state, limit,
                                  items_to_sort, {prev});
            prev = run_partition(exec_q, _in, _out, pstate, {prev});
            std::swap(_in, _out);
        }
        prev = run_pick_pivot(exec_q, _in, _out, k, state, limit, items_to_sort,
                              {prev});

        return prev;
    }

    static std::tuple<bool, uint64_t, uint64_t, uint64_t>
        impl(sycl::queue &exec_queue,
             const void *v_ain,
             void *v_partitioned,
             const size_t a_size,
             const size_t k,
             const std::vector<sycl::event> &depends)
    {
        const T *ain = static_cast<const T *>(v_ain);
        T *partitioned = static_cast<T *>(v_partitioned);

        State<T> state(exec_queue, a_size, partitioned);
        PartitionState<T> pstate(state);

        auto init_e = state.init(exec_queue, depends);
        init_e = pstate.init(exec_queue, {init_e});

        auto evt = run_kth_element(exec_queue, ain, partitioned, k, state,
                                   pstate, {init_e});

        bool found = false;
        bool left = false;
        uint64_t less_count = 0;
        uint64_t greater_equal_count = 0;
        uint64_t num_elems = 0;
        auto copy_evt = exec_queue.copy(state.target_found, &found, 1, evt);
        copy_evt = exec_queue.copy(state.left, &left, 1, copy_evt);
        copy_evt = exec_queue.copy(state.counters.less_count, &less_count, 1,
                                   copy_evt);
        copy_evt = exec_queue.copy(state.counters.greater_equal_count,
                                   &greater_equal_count, 1, copy_evt);
        copy_evt = exec_queue.copy(state.num_elems, &num_elems, 1, copy_evt);

        copy_evt.wait();

        uint64_t buff_offset = 0;
        uint64_t elems_offset = less_count;
        if (!found) {
            if (left) {
                elems_offset = less_count - num_elems;
            }
            else {
                buff_offset = a_size - num_elems;
            }
        }
        else {
            num_elems = 2;
            elems_offset = k;
        }

        state.cleanup(exec_queue);

        return {found, buff_offset, elems_offset, num_elems};
    }
};

using SupportedTypes =
    std::tuple<uint32_t, int32_t, uint64_t, int64_t, float, double>;
} // namespace

KthElement1d::KthElement1d() : dispatch_table("a")
{
    dispatch_table.populate_dispatch_table<SupportedTypes, KthElementF>();
}

std::tuple<bool, uint64_t, uint64_t, uint64_t>
    KthElement1d::call(const dpctl::tensor::usm_ndarray &a,
                       dpctl::tensor::usm_ndarray &partitioned,
                       const size_t k,
                       const std::vector<sycl::event> &depends)
{
    // validate(a, partitioned, k);

    const int a_typenum = a.get_typenum();
    auto kth_elem_func = dispatch_table.get(a_typenum);

    auto exec_q = a.get_queue();
    auto result = kth_elem_func(exec_q, a.get_data(), partitioned.get_data(),
                                a.get_shape(0), k, depends);

    return result;
}

std::unique_ptr<KthElement1d> kth;

void statistics::partitioning::populate_kth_element1d(py::module_ m)
{
    using namespace std::placeholders;

    kth.reset(new KthElement1d());

    auto kth_func = [kthp = kth.get()](
                        const dpctl::tensor::usm_ndarray &a,
                        dpctl::tensor::usm_ndarray &partitioned, const size_t k,
                        const std::vector<sycl::event> &depends) {
        return kthp->call(a, partitioned, k, depends);
    };

    m.def("kth_element", kth_func,
          "Find the k-th and (k+1)-th smallest elements of the input array.",
          py::arg("a"), py::arg("partitioned"), py::arg("k"),
          py::arg("depends") = py::list());

    auto kth_dtypes = [kthp = kth.get()]() {
        return kthp->dispatch_table.get_all_supported_types();
    };

    m.def("kth_element_dtypes", kth_dtypes,
          "Get the supported data types for kth_element.");
}
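
The populate_kth_element1d entry point above is intended to be invoked while the statistics pybind11 module is being built, presumably from statistics_py.cpp (one of the other changed files in this commit, not shown in this excerpt). A minimal, hypothetical sketch of such a registration follows; the module name _statistics_impl and the forward declaration are assumptions made for illustration, not taken from the commit.

// Hypothetical registration sketch; module name and wiring are assumptions.
#include <pybind11/pybind11.h>

namespace py = pybind11;

namespace statistics::partitioning
{
// Mirrors what kth_element1d.hpp presumably declares.
void populate_kth_element1d(py::module_ m);
} // namespace statistics::partitioning

PYBIND11_MODULE(_statistics_impl, m)
{
    // Register kth_element and kth_element_dtypes on the extension module.
    statistics::partitioning::populate_kth_element1d(m);
}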

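Taken together, run_pick_pivot and run_partition implement a quickselect-style search on the device: each round samples the active segment (or sorts it outright once it holds no more than limit elements) to choose a pivot near the sample median, partitions the data around that pivot, and keeps only the side that contains index k; run_kth_element derives its round count from limit = 4 * items_to_sort = 512 as ceil(log2(n / limit)), e.g. ceil(log2(1e6 / 512)) = 11 rounds for a million elements. For reference only, the host-side sketch below illustrates the result the kth_element binding appears to produce, the values at sorted positions k and k + 1 (0-based); it is an illustration of the semantics, not the extension's SYCL implementation, and the helper name is invented.

// Reference-only sketch of the kth_element semantics on the host.
#include <algorithm>
#include <cstddef>
#include <stdexcept>
#include <utility>
#include <vector>

// Returns the k-th and (k+1)-th smallest values of `data` (0-based k).
// Requires k + 1 < data.size().
template <typename T>
std::pair<T, T> kth_element_reference(std::vector<T> data, std::size_t k)
{
    if (k + 1 >= data.size()) {
        throw std::out_of_range("k + 1 must be a valid index");
    }
    auto nth = data.begin() + static_cast<std::ptrdiff_t>(k);
    // Place the k-th smallest value at position k; the tail holds only
    // values that are greater than or equal to it.
    std::nth_element(data.begin(), nth, data.end());
    T kth = *nth;
    // The (k+1)-th smallest overall is the minimum of that tail.
    T kth_plus_1 = *std::min_element(nth + 1, data.end());
    return {kth, kth_plus_1};
}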