|
28 | 28 | #include <cstddef>
|
29 | 29 | #include <vector>
|
30 | 30 |
|
| 31 | +#include "math_utils.hpp" |
| 32 | + |
31 | 33 | namespace dpctl
|
32 | 34 | {
|
33 | 35 | namespace tensor
|
34 | 36 | {
|
35 | 37 | namespace sycl_utils
|
36 | 38 | {
|
namespace detail
{

// Compile-time list of types. Only the two specializations below are defined.
template <typename...> struct TypeList;

// Empty list: has no `head`/`tail` members and ends every recursion over a
// TypeList.
template <> struct TypeList<>
{
};

// Non-empty list: `head` names the first type, `tail` is the list of the
// remaining types (possibly the empty list).
template <typename First, typename... Rest> struct TypeList<First, Rest...>
{
    using head = First;
    using tail = TypeList<Rest...>;
};

using NullTypeList = TypeList<>;

// std::true_type when T is exactly the empty TypeList, std::false_type
// otherwise.
template <typename T>
struct IsNullTypeList : std::bool_constant<std::is_same_v<T, NullTypeList>>
{
};

// Membership test: walks TList from head to tail and derives from
// std::true_type as soon as an entry equals T with top-level cv-qualifiers
// removed. Entries of TList are compared as written (not cv-stripped).
template <typename T, typename TList>
struct IsContained
    : std::conditional_t<
          std::is_same_v<std::remove_cv_t<T>, typename TList::head>,
          std::true_type,
          IsContained<T, typename TList::tail>>
{
};

// Reaching the empty list means no entry matched.
template <typename T> struct IsContained<T, NullTypeList> : std::false_type
{
};

// std::true_type only for specializations of std::complex.
template <class T> struct IsComplex : std::false_type
{
};
template <class T> struct IsComplex<std::complex<T>> : std::true_type
{
};

} // namespace detail
| 85 | + |
| 86 | +template <typename T> |
| 87 | +using sycl_ops = detail::TypeList<sycl::plus<T>, |
| 88 | + sycl::bit_or<T>, |
| 89 | + sycl::bit_xor<T>, |
| 90 | + sycl::bit_and<T>, |
| 91 | + sycl::maximum<T>, |
| 92 | + sycl::minimum<T>, |
| 93 | + sycl::multiplies<T>>; |
| 94 | + |
| 95 | +template <typename T, typename Op> struct IsSyclOp |
| 96 | +{ |
| 97 | + static constexpr bool value = |
| 98 | + detail::IsContained<Op, sycl_ops<std::remove_const_t<T>>>::value || |
| 99 | + detail::IsContained<Op, sycl_ops<std::add_const_t<T>>>::value || |
| 100 | + detail::IsContained<Op, sycl_ops<void>>::value; |
| 101 | +}; |
| 102 | + |
| 103 | +struct AtomicSupport |
| 104 | +{ |
| 105 | + bool operator()(const sycl::queue &exec_q, |
| 106 | + sycl::usm::alloc usm_alloc_type, |
| 107 | + bool require_atomic64 = false) const |
| 108 | + { |
| 109 | + bool supports_atomics = false; |
| 110 | + |
| 111 | + const sycl::device &dev = exec_q.get_device(); |
| 112 | + if (require_atomic64) { |
| 113 | + if (!dev.has(sycl::aspect::atomic64)) |
| 114 | + return false; |
| 115 | + } |
| 116 | + |
| 117 | + switch (usm_alloc_type) { |
| 118 | + case sycl::usm::alloc::shared: |
| 119 | + supports_atomics = |
| 120 | + dev.has(sycl::aspect::usm_atomic_shared_allocations); |
| 121 | + break; |
| 122 | + case sycl::usm::alloc::host: |
| 123 | + supports_atomics = |
| 124 | + dev.has(sycl::aspect::usm_atomic_host_allocations); |
| 125 | + break; |
| 126 | + case sycl::usm::alloc::device: |
| 127 | + supports_atomics = true; |
| 128 | + break; |
| 129 | + default: |
| 130 | + supports_atomics = false; |
| 131 | + } |
| 132 | + |
| 133 | + return supports_atomics; |
| 134 | + } |
| 135 | +}; |
37 | 136 |
|
38 | 137 | /*! @brief Find the smallest multiple of supported sub-group size larger than
|
39 | 138 | * nelems */
|
@@ -66,6 +165,172 @@ size_t choose_workgroup_size(const size_t nelems,
|
66 | 165 | return wg;
|
67 | 166 | }
|
68 | 167 |
|
| 168 | +template <typename T, typename GroupT, typename LocAccT, typename OpT> |
| 169 | +T custom_reduce_over_group(GroupT wg, |
| 170 | + LocAccT local_mem_acc, |
| 171 | + T local_val, |
| 172 | + OpT op) |
| 173 | +{ |
| 174 | + size_t wgs = wg.get_local_linear_range(); |
| 175 | + local_mem_acc[wg.get_local_linear_id()] = local_val; |
| 176 | + |
| 177 | + sycl::group_barrier(wg, sycl::memory_scope::work_group); |
| 178 | + |
| 179 | + T red_val_over_wg = local_mem_acc[0]; |
| 180 | + if (wg.leader()) { |
| 181 | + for (size_t i = 1; i < wgs; ++i) { |
| 182 | + red_val_over_wg = op(red_val_over_wg, local_mem_acc[i]); |
| 183 | + } |
| 184 | + } |
| 185 | + |
| 186 | + sycl::group_barrier(wg, sycl::memory_scope::work_group); |
| 187 | + |
| 188 | + return sycl::group_broadcast(wg, red_val_over_wg); |
| 189 | +} |
| 190 | + |
| 191 | +// Reduction functors |
| 192 | + |
| 193 | +// Maximum |
| 194 | + |
| 195 | +template <typename T> struct Maximum |
| 196 | +{ |
| 197 | + T operator()(const T &x, const T &y) const |
| 198 | + { |
| 199 | + if constexpr (detail::IsComplex<T>::value) { |
| 200 | + using dpctl::tensor::math_utils::max_complex; |
| 201 | + return max_complex<T>(x, y); |
| 202 | + } |
| 203 | + else if constexpr (std::is_floating_point_v<T> || |
| 204 | + std::is_same_v<T, sycl::half>) { |
| 205 | + return (std::isnan(x) || x > y) ? x : y; |
| 206 | + } |
| 207 | + else if constexpr (std::is_same_v<T, bool>) { |
| 208 | + return x || y; |
| 209 | + } |
| 210 | + else { |
| 211 | + return (x > y) ? x : y; |
| 212 | + } |
| 213 | + } |
| 214 | +}; |
| 215 | + |
| 216 | +// Minimum |
| 217 | + |
| 218 | +template <typename T> struct Minimum |
| 219 | +{ |
| 220 | + T operator()(const T &x, const T &y) const |
| 221 | + { |
| 222 | + if constexpr (detail::IsComplex<T>::value) { |
| 223 | + using dpctl::tensor::math_utils::min_complex; |
| 224 | + return min_complex<T>(x, y); |
| 225 | + } |
| 226 | + else if constexpr (std::is_floating_point_v<T> || |
| 227 | + std::is_same_v<T, sycl::half>) { |
| 228 | + return (std::isnan(x) || x < y) ? x : y; |
| 229 | + } |
| 230 | + else if constexpr (std::is_same_v<T, bool>) { |
| 231 | + return x && y; |
| 232 | + } |
| 233 | + else { |
| 234 | + return (x < y) ? x : y; |
| 235 | + } |
| 236 | + } |
| 237 | +}; |
| 238 | + |
| 239 | +// Define identities and operator checking structs |
| 240 | + |
/*! @brief Primary template: deliberately has no `value` member. Partial
 * specializations selected through the third (SFINAE) parameter supply the
 * identity element for the operator/type pairs they cover. */
template <typename Op, typename T, typename Enable = void> struct GetIdentity
{
};
| 244 | + |
| 245 | +// Maximum |
| 246 | + |
| 247 | +template <typename T, class Op> |
| 248 | +using IsMaximum = std::bool_constant<std::is_same_v<Op, sycl::maximum<T>> || |
| 249 | + std::is_same_v<Op, sycl::maximum<void>> || |
| 250 | + std::is_same_v<Op, Maximum<T>> || |
| 251 | + std::is_same_v<Op, Maximum<void>>>; |
| 252 | + |
| 253 | +template <typename Op, typename T> |
| 254 | +struct GetIdentity<Op, T, std::enable_if_t<IsMaximum<T, Op>::value>> |
| 255 | +{ |
| 256 | + static constexpr T value = |
| 257 | + static_cast<T>(std::numeric_limits<T>::has_infinity |
| 258 | + ? static_cast<T>(-std::numeric_limits<T>::infinity()) |
| 259 | + : std::numeric_limits<T>::lowest()); |
| 260 | +}; |
| 261 | + |
| 262 | +template <typename Op> |
| 263 | +struct GetIdentity<Op, bool, std::enable_if_t<IsMaximum<bool, Op>::value>> |
| 264 | +{ |
| 265 | + static constexpr bool value = false; |
| 266 | +}; |
| 267 | + |
| 268 | +template <typename Op, typename T> |
| 269 | +struct GetIdentity<Op, |
| 270 | + std::complex<T>, |
| 271 | + std::enable_if_t<IsMaximum<std::complex<T>, Op>::value>> |
| 272 | +{ |
| 273 | + static constexpr std::complex<T> value{-std::numeric_limits<T>::infinity(), |
| 274 | + -std::numeric_limits<T>::infinity()}; |
| 275 | +}; |
| 276 | + |
| 277 | +// Minimum |
| 278 | + |
| 279 | +template <typename T, class Op> |
| 280 | +using IsMinimum = std::bool_constant<std::is_same_v<Op, sycl::minimum<T>> || |
| 281 | + std::is_same_v<Op, sycl::minimum<void>> || |
| 282 | + std::is_same_v<Op, Minimum<T>> || |
| 283 | + std::is_same_v<Op, Minimum<void>>>; |
| 284 | + |
| 285 | +template <typename Op, typename T> |
| 286 | +struct GetIdentity<Op, T, std::enable_if_t<IsMinimum<T, Op>::value>> |
| 287 | +{ |
| 288 | + static constexpr T value = |
| 289 | + static_cast<T>(std::numeric_limits<T>::has_infinity |
| 290 | + ? static_cast<T>(std::numeric_limits<T>::infinity()) |
| 291 | + : std::numeric_limits<T>::max()); |
| 292 | +}; |
| 293 | + |
| 294 | +template <typename Op> |
| 295 | +struct GetIdentity<Op, bool, std::enable_if_t<IsMinimum<bool, Op>::value>> |
| 296 | +{ |
| 297 | + static constexpr bool value = true; |
| 298 | +}; |
| 299 | + |
| 300 | +template <typename Op, typename T> |
| 301 | +struct GetIdentity<Op, |
| 302 | + std::complex<T>, |
| 303 | + std::enable_if_t<IsMinimum<std::complex<T>, Op>::value>> |
| 304 | +{ |
| 305 | + static constexpr std::complex<T> value{std::numeric_limits<T>::infinity(), |
| 306 | + std::numeric_limits<T>::infinity()}; |
| 307 | +}; |
| 308 | + |
| 309 | +// Plus |
| 310 | + |
| 311 | +template <typename T, class Op> |
| 312 | +using IsPlus = std::bool_constant< |
| 313 | + std::is_same_v<Op, sycl::plus<T>> || std::is_same_v<Op, sycl::plus<void>> || |
| 314 | + std::is_same_v<Op, std::plus<T>> || std::is_same_v<Op, std::plus<T>>>; |
| 315 | + |
| 316 | +// Identity |
| 317 | + |
/*! @brief Primary template: deliberately has no `value` member. The partial
 * specializations, selected through the third (SFINAE) parameter, provide the
 * identity element of `Op` for `T`. */
template <typename Op, typename T, typename Enable = void> struct Identity
{
};
| 321 | + |
| 322 | +template <typename Op, typename T> |
| 323 | +struct Identity<Op, T, std::enable_if_t<!IsSyclOp<T, Op>::value>> |
| 324 | +{ |
| 325 | + static constexpr T value = GetIdentity<Op, T>::value; |
| 326 | +}; |
| 327 | + |
| 328 | +template <typename Op, typename T> |
| 329 | +struct Identity<Op, T, std::enable_if_t<IsSyclOp<T, Op>::value>> |
| 330 | +{ |
| 331 | + static constexpr T value = sycl::known_identity<Op, T>::value; |
| 332 | +}; |
| 333 | + |
69 | 334 | } // namespace sycl_utils
|
70 | 335 | } // namespace tensor
|
71 | 336 | } // namespace dpctl
|
0 commit comments