|
| 1 | +//===----------------------------------------------------------------------===// |
| 2 | +// |
| 3 | +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | +// See https://llvm.org/LICENSE.txt for license information. |
| 5 | +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | +// |
| 7 | +//===----------------------------------------------------------------------===// |
| 8 | + |
| 9 | +#ifndef _LIBCPP___ALGORITHM_PSTL_BACKENDS_CPU_BACKENDS_TRANSFORM_REDUCE_H |
| 10 | +#define _LIBCPP___ALGORITHM_PSTL_BACKENDS_CPU_BACKENDS_TRANSFORM_REDUCE_H |
| 11 | + |
| 12 | +#include <__algorithm/pstl_backends/cpu_backends/backend.h> |
| 13 | +#include <__config> |
| 14 | +#include <__iterator/iterator_traits.h> |
| 15 | +#include <__numeric/transform_reduce.h> |
| 16 | +#include <__type_traits/is_arithmetic.h> |
| 17 | +#include <__type_traits/is_execution_policy.h> |
| 18 | +#include <__type_traits/operation_traits.h> |
| 19 | +#include <__utility/move.h> |
| 20 | +#include <__utility/terminate_on_exception.h> |
| 21 | +#include <new> |
| 22 | + |
| 23 | +#if !defined(_LIBCPP_HAS_NO_PRAGMA_SYSTEM_HEADER) |
| 24 | +# pragma GCC system_header |
| 25 | +#endif |
| 26 | + |
| 27 | +#if !defined(_LIBCPP_HAS_NO_INCOMPLETE_PSTL) && _LIBCPP_STD_VER >= 17 |
| 28 | + |
| 29 | +_LIBCPP_BEGIN_NAMESPACE_STD |
| 30 | + |
| 31 | +template < |
| 32 | + typename _DifferenceType, |
| 33 | + typename _Tp, |
| 34 | + typename _BinaryOperation, |
| 35 | + typename _UnaryOperation, |
| 36 | + __enable_if_t<__is_trivial_plus_operation<_BinaryOperation, _Tp, _Tp>::value && is_arithmetic_v<_Tp>, int> = 0> |
| 37 | +_LIBCPP_HIDE_FROM_ABI _Tp |
| 38 | +__simd_transform_reduce(_DifferenceType __n, _Tp __init, _BinaryOperation, _UnaryOperation __f) noexcept { |
| 39 | + _PSTL_PRAGMA_SIMD_REDUCTION(+ : __init) |
| 40 | + for (_DifferenceType __i = 0; __i < __n; ++__i) |
| 41 | + __init += __f(__i); |
| 42 | + return __init; |
| 43 | +} |
| 44 | + |
| 45 | +template < |
| 46 | + typename _Size, |
| 47 | + typename _Tp, |
| 48 | + typename _BinaryOperation, |
| 49 | + typename _UnaryOperation, |
| 50 | + __enable_if_t<!(__is_trivial_plus_operation<_BinaryOperation, _Tp, _Tp>::value && is_arithmetic_v<_Tp>), int> = 0> |
| 51 | +_LIBCPP_HIDE_FROM_ABI _Tp |
| 52 | +__simd_transform_reduce(_Size __n, _Tp __init, _BinaryOperation __binary_op, _UnaryOperation __f) noexcept { |
| 53 | + const _Size __block_size = __lane_size / sizeof(_Tp); |
| 54 | + if (__n > 2 * __block_size && __block_size > 1) { |
| 55 | + alignas(__lane_size) char __lane_buffer[__lane_size]; |
| 56 | + _Tp* __lane = reinterpret_cast<_Tp*>(__lane_buffer); |
| 57 | + |
| 58 | + // initializer |
| 59 | + _PSTL_PRAGMA_SIMD |
| 60 | + for (_Size __i = 0; __i < __block_size; ++__i) { |
| 61 | + ::new (__lane + __i) _Tp(__binary_op(__f(__i), __f(__block_size + __i))); |
| 62 | + } |
| 63 | + // main loop |
| 64 | + _Size __i = 2 * __block_size; |
| 65 | + const _Size __last_iteration = __block_size * (__n / __block_size); |
| 66 | + for (; __i < __last_iteration; __i += __block_size) { |
| 67 | + _PSTL_PRAGMA_SIMD |
| 68 | + for (_Size __j = 0; __j < __block_size; ++__j) { |
| 69 | + __lane[__j] = __binary_op(std::move(__lane[__j]), __f(__i + __j)); |
| 70 | + } |
| 71 | + } |
| 72 | + // remainder |
| 73 | + _PSTL_PRAGMA_SIMD |
| 74 | + for (_Size __j = 0; __j < __n - __last_iteration; ++__j) { |
| 75 | + __lane[__j] = __binary_op(std::move(__lane[__j]), __f(__last_iteration + __j)); |
| 76 | + } |
| 77 | + // combiner |
| 78 | + for (_Size __j = 0; __j < __block_size; ++__j) { |
| 79 | + __init = __binary_op(std::move(__init), std::move(__lane[__j])); |
| 80 | + } |
| 81 | + // destroyer |
| 82 | + _PSTL_PRAGMA_SIMD |
| 83 | + for (_Size __j = 0; __j < __block_size; ++__j) { |
| 84 | + __lane[__j].~_Tp(); |
| 85 | + } |
| 86 | + } else { |
| 87 | + for (_Size __i = 0; __i < __n; ++__i) { |
| 88 | + __init = __binary_op(std::move(__init), __f(__i)); |
| 89 | + } |
| 90 | + } |
| 91 | + return __init; |
| 92 | +} |
| 93 | + |
| 94 | +template <class _ExecutionPolicy, |
| 95 | + class _ForwardIterator1, |
| 96 | + class _ForwardIterator2, |
| 97 | + class _Tp, |
| 98 | + class _BinaryOperation1, |
| 99 | + class _BinaryOperation2> |
| 100 | +_LIBCPP_HIDE_FROM_ABI _Tp __pstl_transform_reduce( |
| 101 | + __cpu_backend_tag, |
| 102 | + _ForwardIterator1 __first1, |
| 103 | + _ForwardIterator1 __last1, |
| 104 | + _ForwardIterator2 __first2, |
| 105 | + _Tp __init, |
| 106 | + _BinaryOperation1 __reduce, |
| 107 | + _BinaryOperation2 __transform) { |
| 108 | + if constexpr (__is_parallel_execution_policy_v<_ExecutionPolicy> && |
| 109 | + __has_random_access_iterator_category<_ForwardIterator1>::value && |
| 110 | + __has_random_access_iterator_category<_ForwardIterator2>::value) { |
| 111 | + return std::__terminate_on_exception([&] { |
| 112 | + return __par_backend::__parallel_transform_reduce( |
| 113 | + __first1, |
| 114 | + std::move(__last1), |
| 115 | + [__first1, __first2, __transform](_ForwardIterator1 __iter) { |
| 116 | + return __transform(*__iter, *(__first2 + (__iter - __first1))); |
| 117 | + }, |
| 118 | + std::move(__init), |
| 119 | + std::move(__reduce), |
| 120 | + [__first1, __first2, __reduce, __transform]( |
| 121 | + _ForwardIterator1 __brick_first, _ForwardIterator1 __brick_last, _Tp __brick_init) { |
| 122 | + return std::__pstl_transform_reduce<__remove_parallel_policy_t<_ExecutionPolicy>>( |
| 123 | + __cpu_backend_tag{}, |
| 124 | + __brick_first, |
| 125 | + std::move(__brick_last), |
| 126 | + __first2 + (__brick_first - __first1), |
| 127 | + std::move(__brick_init), |
| 128 | + std::move(__reduce), |
| 129 | + std::move(__transform)); |
| 130 | + }); |
| 131 | + }); |
| 132 | + } else if constexpr (__is_unsequenced_execution_policy_v<_ExecutionPolicy> && |
| 133 | + __has_random_access_iterator_category<_ForwardIterator1>::value && |
| 134 | + __has_random_access_iterator_category<_ForwardIterator2>::value) { |
| 135 | + return std::__simd_transform_reduce( |
| 136 | + __last1 - __first1, std::move(__init), std::move(__reduce), [&](__iter_diff_t<_ForwardIterator1> __i) { |
| 137 | + return __transform(__first1[__i], __first2[__i]); |
| 138 | + }); |
| 139 | + } else { |
| 140 | + return std::transform_reduce( |
| 141 | + std::move(__first1), |
| 142 | + std::move(__last1), |
| 143 | + std::move(__first2), |
| 144 | + std::move(__init), |
| 145 | + std::move(__reduce), |
| 146 | + std::move(__transform)); |
| 147 | + } |
| 148 | +} |
| 149 | + |
| 150 | +template <class _ExecutionPolicy, class _ForwardIterator, class _Tp, class _BinaryOperation, class _UnaryOperation> |
| 151 | +_LIBCPP_HIDE_FROM_ABI _Tp __pstl_transform_reduce( |
| 152 | + __cpu_backend_tag, |
| 153 | + _ForwardIterator __first, |
| 154 | + _ForwardIterator __last, |
| 155 | + _Tp __init, |
| 156 | + _BinaryOperation __reduce, |
| 157 | + _UnaryOperation __transform) { |
| 158 | + if constexpr (__is_parallel_execution_policy_v<_ExecutionPolicy> && |
| 159 | + __has_random_access_iterator_category<_ForwardIterator>::value) { |
| 160 | + return std::__terminate_on_exception([&] { |
| 161 | + return __par_backend::__parallel_transform_reduce( |
| 162 | + std::move(__first), |
| 163 | + std::move(__last), |
| 164 | + [__transform](_ForwardIterator __iter) { return __transform(*__iter); }, |
| 165 | + std::move(__init), |
| 166 | + std::move(__reduce), |
| 167 | + [=](_ForwardIterator __brick_first, _ForwardIterator __brick_last, _Tp __brick_init) { |
| 168 | + return std::__pstl_transform_reduce<__remove_parallel_policy_t<_ExecutionPolicy>>( |
| 169 | + __cpu_backend_tag{}, |
| 170 | + std::move(__brick_first), |
| 171 | + std::move(__brick_last), |
| 172 | + std::move(__brick_init), |
| 173 | + std::move(__reduce), |
| 174 | + std::move(__transform)); |
| 175 | + }); |
| 176 | + }); |
| 177 | + } else if constexpr (__is_unsequenced_execution_policy_v<_ExecutionPolicy> && |
| 178 | + __has_random_access_iterator_category<_ForwardIterator>::value) { |
| 179 | + return std::__simd_transform_reduce( |
| 180 | + __last - __first, |
| 181 | + std::move(__init), |
| 182 | + std::move(__reduce), |
| 183 | + [=, &__transform](__iter_diff_t<_ForwardIterator> __i) { return __transform(__first[__i]); }); |
| 184 | + } else { |
| 185 | + return std::transform_reduce( |
| 186 | + std::move(__first), std::move(__last), std::move(__init), std::move(__reduce), std::move(__transform)); |
| 187 | + } |
| 188 | +} |
| 189 | + |
| 190 | +_LIBCPP_END_NAMESPACE_STD |
| 191 | + |
| 192 | +#endif // !defined(_LIBCPP_HAS_NO_INCOMPLETE_PSTL) && _LIBCPP_STD_VER >= 17 |
| 193 | + |
| 194 | +#endif // _LIBCPP___ALGORITHM_PSTL_BACKENDS_CPU_BACKENDS_TRANSFORM_REDUCE_H |
0 commit comments