Skip to content

Commit 393f8d5

Browse files
Merge branch 'master' into remove_dot_and_multiply_from_backend
2 parents 30fbf2b + b990bac commit 393f8d5

32 files changed

+2353
-1476
lines changed

dpnp/backend/extensions/statistics/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ set(_module_src
2929
${CMAKE_CURRENT_SOURCE_DIR}/bincount.cpp
3030
${CMAKE_CURRENT_SOURCE_DIR}/common.cpp
3131
${CMAKE_CURRENT_SOURCE_DIR}/histogram.cpp
32+
${CMAKE_CURRENT_SOURCE_DIR}/histogramdd.cpp
3233
${CMAKE_CURRENT_SOURCE_DIR}/histogram_common.cpp
3334
${CMAKE_CURRENT_SOURCE_DIR}/sliding_dot_product1d.cpp
3435
${CMAKE_CURRENT_SOURCE_DIR}/sliding_window1d.cpp

dpnp/backend/extensions/statistics/bincount.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,11 @@
2525

2626
#pragma once
2727

28-
#include <dpctl4pybind11.hpp>
28+
#include <pybind11/pybind11.h>
2929
#include <sycl/sycl.hpp>
3030

3131
#include "dispatch_table.hpp"
32+
#include "dpctl4pybind11.hpp"
3233

3334
namespace dpctl_td_ns = dpctl::tensor::type_dispatch;
3435

dpnp/backend/extensions/statistics/common.hpp

Lines changed: 32 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -28,13 +28,12 @@
2828
#include <complex>
2929
#include <pybind11/numpy.h>
3030
#include <pybind11/pybind11.h>
31-
32-
// clang-format off
33-
// math_utils.hpp doesn't include sycl header but uses sycl types
34-
// so sycl.hpp must be included before math_utils.hpp
3531
#include <sycl/sycl.hpp>
32+
3633
#include "utils/math_utils.hpp"
37-
// clang-format on
34+
#include "utils/type_utils.hpp"
35+
36+
namespace type_utils = dpctl::tensor::type_utils;
3837

3938
namespace statistics::common
4039
{
@@ -54,24 +53,20 @@ constexpr auto Align(N n, D d)
5453
template <typename T, sycl::memory_order Order, sycl::memory_scope Scope>
5554
struct AtomicOp
5655
{
57-
static void add(T &lhs, const T value)
56+
static void add(T &lhs, const T &value)
5857
{
59-
sycl::atomic_ref<T, Order, Scope> lh(lhs);
60-
lh += value;
61-
}
62-
};
58+
if constexpr (type_utils::is_complex_v<T>) {
59+
using vT = typename T::value_type;
60+
vT *_lhs = reinterpret_cast<vT(&)[2]>(lhs);
61+
const vT *_val = reinterpret_cast<const vT(&)[2]>(value);
6362

64-
template <typename T, sycl::memory_order Order, sycl::memory_scope Scope>
65-
struct AtomicOp<std::complex<T>, Order, Scope>
66-
{
67-
static void add(std::complex<T> &lhs, const std::complex<T> value)
68-
{
69-
T *_lhs = reinterpret_cast<T(&)[2]>(lhs);
70-
const T *_val = reinterpret_cast<const T(&)[2]>(value);
71-
sycl::atomic_ref<T, Order, Scope> lh0(_lhs[0]);
72-
lh0 += _val[0];
73-
sycl::atomic_ref<T, Order, Scope> lh1(_lhs[1]);
74-
lh1 += _val[1];
63+
AtomicOp<vT, Order, Scope>::add(_lhs[0], _val[0]);
64+
AtomicOp<vT, Order, Scope>::add(_lhs[1], _val[1]);
65+
}
66+
else {
67+
sycl::atomic_ref<T, Order, Scope> lh(lhs);
68+
lh += value;
69+
}
7570
}
7671
};
7772

@@ -80,17 +75,12 @@ struct Less
8075
{
8176
bool operator()(const T &lhs, const T &rhs) const
8277
{
83-
return std::less{}(lhs, rhs);
84-
}
85-
};
86-
87-
template <typename T>
88-
struct Less<std::complex<T>>
89-
{
90-
bool operator()(const std::complex<T> &lhs,
91-
const std::complex<T> &rhs) const
92-
{
93-
return dpctl::tensor::math_utils::less_complex(lhs, rhs);
78+
if constexpr (type_utils::is_complex_v<T>) {
79+
return dpctl::tensor::math_utils::less_complex(lhs, rhs);
80+
}
81+
else {
82+
return std::less{}(lhs, rhs);
83+
}
9484
}
9585
};
9686

@@ -99,26 +89,23 @@ struct IsNan
9989
{
10090
static bool isnan(const T &v)
10191
{
102-
if constexpr (std::is_floating_point_v<T> ||
103-
std::is_same_v<T, sycl::half>) {
92+
if constexpr (type_utils::is_complex_v<T>) {
93+
using vT = typename T::value_type;
94+
95+
const vT real1 = std::real(v);
96+
const vT imag1 = std::imag(v);
97+
98+
return IsNan<vT>::isnan(real1) || IsNan<vT>::isnan(imag1);
99+
}
100+
else if constexpr (std::is_floating_point_v<T> ||
101+
std::is_same_v<T, sycl::half>) {
104102
return sycl::isnan(v);
105103
}
106104

107105
return false;
108106
}
109107
};
110108

111-
template <typename T>
112-
struct IsNan<std::complex<T>>
113-
{
114-
static bool isnan(const std::complex<T> &v)
115-
{
116-
T real1 = std::real(v);
117-
T imag1 = std::imag(v);
118-
return sycl::isnan(real1) || sycl::isnan(imag1);
119-
}
120-
};
121-
122109
size_t get_max_local_size(const sycl::device &device);
123110
size_t get_max_local_size(const sycl::device &device,
124111
int cpu_local_size_limit,

dpnp/backend/extensions/statistics/histogram.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,11 +25,11 @@
2525

2626
#pragma once
2727

28+
#include <pybind11/pybind11.h>
2829
#include <sycl/sycl.hpp>
2930

3031
#include "dispatch_table.hpp"
31-
32-
// namespace dpctl_td_ns = dpctl::tensor::type_dispatch;
32+
#include "dpctl4pybind11.hpp"
3333

3434
namespace statistics::histogram
3535
{

dpnp/backend/extensions/statistics/histogram_common.cpp

Lines changed: 55 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -88,16 +88,14 @@ void validate(const usm_ndarray &sample,
8888
{&histogram}, names);
8989

9090
check_size_at_least(bins_ptr, 2, names);
91-
9291
check_size_at_least(&histogram, 1, names);
93-
check_num_dims(&histogram, 1, names);
9492

9593
if (weights_ptr) {
9694
check_num_dims(weights_ptr, 1, names);
9795

98-
auto sample_size = sample.get_size();
96+
auto sample_size = sample.get_shape(0);
9997
auto weights_size = weights_ptr->get_size();
100-
if (sample.get_size() != weights_ptr->get_size()) {
98+
if (sample_size != weights_ptr->get_size()) {
10199
throw py::value_error(name_of(&sample, names) + " size (" +
102100
std::to_string(sample_size) + ") and " +
103101
name_of(weights_ptr, names) + " size (" +
@@ -110,61 +108,74 @@ void validate(const usm_ndarray &sample,
110108

111109
if (sample.get_ndim() == 1) {
112110
check_num_dims(bins_ptr, 1, names);
111+
112+
if (bins_ptr && histogram.get_size() != bins_ptr->get_size() - 1) {
113+
auto hist_size = histogram.get_size();
114+
auto bins_size = bins_ptr->get_size();
115+
throw py::value_error(
116+
name_of(&histogram, names) + " parameter and " +
117+
name_of(bins_ptr, names) + " parameters shape mismatch. " +
118+
name_of(&histogram, names) + " size is " +
119+
std::to_string(hist_size) + name_of(bins_ptr, names) +
120+
" must have size " + std::to_string(hist_size + 1) +
121+
" but have " + std::to_string(bins_size));
122+
}
113123
}
114124
else if (sample.get_ndim() == 2) {
115125
auto sample_count = sample.get_shape(0);
116126
auto expected_dims = sample.get_shape(1);
117127

118-
if (bins_ptr != nullptr && bins_ptr->get_ndim() != expected_dims) {
128+
if (histogram.get_ndim() != expected_dims) {
119129
throw py::value_error(
120-
name_of(&sample, names) + " parameter has shape {" +
121-
std::to_string(sample_count) + "x" +
122-
std::to_string(expected_dims) + "}" + ", so " +
123-
name_of(bins_ptr, names) + " parameter expected to be " +
130+
name_of(&sample, names) + " parameter has shape (" +
131+
std::to_string(sample_count) + ", " +
132+
std::to_string(expected_dims) + ")" + ", so " +
133+
name_of(&histogram, names) + " parameter expected to be " +
124134
std::to_string(expected_dims) +
125135
"d. "
126136
"Actual " +
127-
std::to_string(bins->get_ndim()) + "d");
137+
std::to_string(histogram.get_ndim()) + "d");
128138
}
129-
}
130139

131-
if (bins_ptr != nullptr) {
132-
py::ssize_t expected_hist_size = 1;
133-
for (int i = 0; i < bins_ptr->get_ndim(); ++i) {
134-
expected_hist_size *= (bins_ptr->get_shape(i) - 1);
140+
if (bins_ptr != nullptr) {
141+
py::ssize_t expected_bins_size = 0;
142+
for (int i = 0; i < histogram.get_ndim(); ++i) {
143+
expected_bins_size += histogram.get_shape(i) + 1;
144+
}
145+
146+
auto actual_bins_size = bins_ptr->get_size();
147+
if (actual_bins_size != expected_bins_size) {
148+
throw py::value_error(
149+
name_of(&histogram, names) + " and " +
150+
name_of(bins_ptr, names) + " shape mismatch. " +
151+
name_of(bins_ptr, names) + " expected to have size = " +
152+
std::to_string(expected_bins_size) + ". Actual " +
153+
std::to_string(actual_bins_size));
154+
}
135155
}
136156

137-
if (histogram.get_size() != expected_hist_size) {
138-
throw py::value_error(
139-
name_of(&histogram, names) + " and " +
140-
name_of(bins_ptr, names) + " shape mismatch. " +
141-
name_of(&histogram, names) + " expected to have size = " +
142-
std::to_string(expected_hist_size) + ". Actual " +
143-
std::to_string(histogram.get_size()));
157+
int64_t max_hist_size = std::numeric_limits<uint32_t>::max() - 1;
158+
if (histogram.get_size() > max_hist_size) {
159+
throw py::value_error(name_of(&histogram, names) +
160+
" parameter size expected to be less than " +
161+
std::to_string(max_hist_size) + ". Actual " +
162+
std::to_string(histogram.get_size()));
144163
}
145-
}
146-
147-
int64_t max_hist_size = std::numeric_limits<uint32_t>::max() - 1;
148-
if (histogram.get_size() > max_hist_size) {
149-
throw py::value_error(name_of(&histogram, names) +
150-
" parameter size expected to be less than " +
151-
std::to_string(max_hist_size) + ". Actual " +
152-
std::to_string(histogram.get_size()));
153-
}
154164

155-
auto array_types = dpctl_td_ns::usm_ndarray_types();
156-
auto hist_type = static_cast<typenum_t>(
157-
array_types.typenum_to_lookup_id(histogram.get_typenum()));
158-
if (histogram.get_elemsize() == 8 && hist_type != typenum_t::CFLOAT) {
159-
auto device = exec_q.get_device();
160-
bool _64bit_atomics = device.has(sycl::aspect::atomic64);
161-
162-
if (!_64bit_atomics) {
163-
auto device_name = device.get_info<sycl::info::device::name>();
164-
throw py::value_error(
165-
name_of(&histogram, names) +
166-
" parameter has 64-bit type, but 64-bit atomics " +
167-
" are not supported for " + device_name);
165+
auto array_types = dpctl_td_ns::usm_ndarray_types();
166+
auto hist_type = static_cast<typenum_t>(
167+
array_types.typenum_to_lookup_id(histogram.get_typenum()));
168+
if (histogram.get_elemsize() == 8 && hist_type != typenum_t::CFLOAT) {
169+
auto device = exec_q.get_device();
170+
bool _64bit_atomics = device.has(sycl::aspect::atomic64);
171+
172+
if (!_64bit_atomics) {
173+
auto device_name = device.get_info<sycl::info::device::name>();
174+
throw py::value_error(
175+
name_of(&histogram, names) +
176+
" parameter has 64-bit type, but 64-bit atomics " +
177+
" are not supported for " + device_name);
178+
}
168179
}
169180
}
170181
}

0 commit comments

Comments
 (0)