-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[libc++] Vectorize std::mismatch with trivially equality comparable types #87716
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[libc++] Vectorize std::mismatch with trivially equality comparable types #87716
Conversation
✅ With the latest revision this PR passed the C/C++ code formatter. |
57dfddd
to
dc09211
Compare
7c667e6
to
5346884
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This LGTM on a high level, but I have some comments and I'd like to see this again before it ships, since it's tricky. But the approach looks fine to me. This is clever!
libcxx/test/std/algorithms/alg.nonmodifying/mismatch/mismatch.pass.cpp
Outdated
Show resolved
Hide resolved
libcxx/test/std/algorithms/alg.nonmodifying/mismatch/mismatch.pass.cpp
Outdated
Show resolved
Hide resolved
@llvm/pr-subscribers-libcxx Author: Nikolas Klauser (philnik777) Changes — Full diff: https://github.com/llvm/llvm-project/pull/87716.diff 7 Files Affected:
diff --git a/libcxx/include/CMakeLists.txt b/libcxx/include/CMakeLists.txt
index 097a41d4c41740..5922c9106853ed 100644
--- a/libcxx/include/CMakeLists.txt
+++ b/libcxx/include/CMakeLists.txt
@@ -456,6 +456,7 @@ set(files
__ios/fpos.h
__iterator/access.h
__iterator/advance.h
+ __iterator/aliasing_iterator.h
__iterator/back_insert_iterator.h
__iterator/bounded_iter.h
__iterator/common_iterator.h
diff --git a/libcxx/include/__algorithm/mismatch.h b/libcxx/include/__algorithm/mismatch.h
index 4ada29eabc470c..c25fae67c7cfb3 100644
--- a/libcxx/include/__algorithm/mismatch.h
+++ b/libcxx/include/__algorithm/mismatch.h
@@ -16,6 +16,7 @@
#include <__algorithm/unwrap_iter.h>
#include <__config>
#include <__functional/identity.h>
+#include <__iterator/aliasing_iterator.h>
#include <__type_traits/desugars_to.h>
#include <__type_traits/invoke.h>
#include <__type_traits/is_constant_evaluated.h>
@@ -55,18 +56,13 @@ __mismatch(_Iter1 __first1, _Sent1 __last1, _Iter2 __first2, _Pred& __pred, _Pro
#if _LIBCPP_VECTORIZE_ALGORITHMS
-template <class _Tp,
- class _Pred,
- class _Proj1,
- class _Proj2,
- __enable_if_t<is_integral<_Tp>::value && __desugars_to_v<__equal_tag, _Pred, _Tp, _Tp> &&
- __is_identity<_Proj1>::value && __is_identity<_Proj2>::value,
- int> = 0>
-_LIBCPP_NODISCARD _LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX20 pair<_Tp*, _Tp*>
-__mismatch(_Tp* __first1, _Tp* __last1, _Tp* __first2, _Pred& __pred, _Proj1& __proj1, _Proj2& __proj2) {
+template <class _Iter>
+_LIBCPP_NODISCARD _LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX20 pair<_Iter, _Iter>
+__mismatch_vectorized(_Iter __first1, _Iter __last1, _Iter __first2) {
+ using __value_type = __iter_value_type<_Iter>;
constexpr size_t __unroll_count = 4;
- constexpr size_t __vec_size = __native_vector_size<_Tp>;
- using __vec = __simd_vector<_Tp, __vec_size>;
+ constexpr size_t __vec_size = __native_vector_size<__value_type>;
+ using __vec = __simd_vector<__value_type, __vec_size>;
if (!__libcpp_is_constant_evaluated()) {
auto __orig_first1 = __first1;
@@ -116,9 +112,41 @@ __mismatch(_Tp* __first1, _Tp* __last1, _Tp* __first2, _Pred& __pred, _Proj1& __
} // else loop over the elements individually
}
- return std::__mismatch_loop(__first1, __last1, __first2, __pred, __proj1, __proj2);
+ __equal_to __pred;
+ __identity __proj;
+ return std::__mismatch_loop(__first1, __last1, __first2, __pred, __proj, __proj);
+}
+
+template <class _Tp,
+ class _Pred,
+ class _Proj1,
+ class _Proj2,
+ __enable_if_t<is_integral<_Tp>::value && __desugars_to_v<__equal_tag, _Pred, _Tp, _Tp> &&
+ __is_identity<_Proj1>::value && __is_identity<_Proj2>::value,
+ int> = 0>
+_LIBCPP_NODISCARD _LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX20 pair<_Tp*, _Tp*>
+__mismatch(_Tp* __first1, _Tp* __last1, _Tp* __first2, _Pred&, _Proj1&, _Proj2&) {
+ return std::__mismatch_vectorized(__first1, __last1, __first2);
}
+template <class _Tp,
+ class _Pred,
+ class _Proj1,
+ class _Proj2,
+ __enable_if_t<!is_integral<_Tp>::value && __desugars_to_v<__equal_tag, _Pred, _Tp, _Tp> &&
+ __is_identity<_Proj1>::value && __is_identity<_Proj2>::value &&
+ __can_map_to_integer_v<_Tp> && __libcpp_is_trivially_equality_comparable<_Tp, _Tp>::value,
+ int> = 0>
+_LIBCPP_NODISCARD _LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX20 pair<_Tp*, _Tp*>
+__mismatch(_Tp* __first1, _Tp* __last1, _Tp* __first2, _Pred& __pred, _Proj1& __proj1, _Proj2& __proj2) {
+ if (__libcpp_is_constant_evaluated()) {
+ return std::__mismatch_loop(__first1, __last1, __first2, __pred, __proj1, __proj2);
+ } else {
+ using _Iter = __aliasing_iterator<_Tp, __get_as_integer_type_t<_Tp>>;
+ auto __ret = std::__mismatch_vectorized(_Iter(__first1), _Iter(__last1), _Iter(__first2));
+ return {__ret.first.base(), __ret.second.base()};
+ }
+}
#endif // _LIBCPP_VECTORIZE_ALGORITHMS
template <class _InputIterator1, class _InputIterator2, class _BinaryPredicate>
diff --git a/libcxx/include/__algorithm/simd_utils.h b/libcxx/include/__algorithm/simd_utils.h
index 989a1957987e1e..7fd7b3c33ee6de 100644
--- a/libcxx/include/__algorithm/simd_utils.h
+++ b/libcxx/include/__algorithm/simd_utils.h
@@ -43,6 +43,34 @@ _LIBCPP_PUSH_MACROS
_LIBCPP_BEGIN_NAMESPACE_STD
+template <class _Tp>
+inline constexpr bool __can_map_to_integer_v =
+ sizeof(_Tp) == alignof(_Tp) && (sizeof(_Tp) == 1 || sizeof(_Tp) == 2 || sizeof(_Tp) == 4 || sizeof(_Tp) == 8);
+
+template <size_t _TypeSize>
+struct __get_as_integer_type_impl;
+
+template <>
+struct __get_as_integer_type_impl<1> {
+ using type = uint8_t;
+};
+
+template <>
+struct __get_as_integer_type_impl<2> {
+ using type = uint16_t;
+};
+template <>
+struct __get_as_integer_type_impl<4> {
+ using type = uint32_t;
+};
+template <>
+struct __get_as_integer_type_impl<8> {
+ using type = uint64_t;
+};
+
+template <class _Tp>
+using __get_as_integer_type_t = typename __get_as_integer_type_impl<sizeof(_Tp)>::type;
+
// This isn't specialized for 64 byte vectors on purpose. They have the potential to significantly reduce performance
// in mixed simd/non-simd workloads and don't provide any performance improvement for currently vectorized algorithms
// as far as benchmarks are concerned.
@@ -80,10 +108,10 @@ template <class _VecT>
using __simd_vector_underlying_type_t = decltype(std::__simd_vector_underlying_type_impl(_VecT{}));
// This isn't inlined without always_inline when loading chars.
-template <class _VecT, class _Tp>
-_LIBCPP_NODISCARD _LIBCPP_ALWAYS_INLINE _LIBCPP_HIDE_FROM_ABI _VecT __load_vector(const _Tp* __ptr) noexcept {
+template <class _VecT, class _Iter>
+_LIBCPP_NODISCARD _LIBCPP_ALWAYS_INLINE _LIBCPP_HIDE_FROM_ABI _VecT __load_vector(_Iter __iter) noexcept {
return [=]<size_t... _Indices>(index_sequence<_Indices...>) _LIBCPP_ALWAYS_INLINE noexcept {
- return _VecT{__ptr[_Indices]...};
+ return _VecT{__iter[_Indices]...};
}(make_index_sequence<__simd_vector_size_v<_VecT>>{});
}
diff --git a/libcxx/include/__iterator/aliasing_iterator.h b/libcxx/include/__iterator/aliasing_iterator.h
new file mode 100644
index 00000000000000..eb2b59af188f7b
--- /dev/null
+++ b/libcxx/include/__iterator/aliasing_iterator.h
@@ -0,0 +1,120 @@
+//===----------------------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef _LIBCPP___ITERATOR_ALIASING_ITERATOR_H
+#define _LIBCPP___ITERATOR_ALIASING_ITERATOR_H
+
+#include <__compare/strong_order.h>
+#include <__config>
+#include <__iterator/concepts.h>
+#include <__iterator/iterator_traits.h>
+#include <__type_traits/is_trivial.h>
+#include <cstddef>
+
+#if !defined(_LIBCPP_HAS_NO_PRAGMA_SYSTEM_HEADER)
+# pragma GCC system_header
+#endif
+
+_LIBCPP_BEGIN_NAMESPACE_STD
+
+// Random-access iterator over a `_BaseT*` whose elements are read as `_Alias`. Dereferencing copies the bytes of the
+// pointed-to `_BaseT` into a local `_Alias` via `__builtin_memcpy` and returns that value, so the iterator never
+// hands out a reference into the underlying storage. `_Alias` must be trivial and have the same size as `_BaseT`.
+//
+// The iterator is nested inside this wrapper struct and is meant to be named through `__aliasing_iterator`.
+template <class _BaseT, class _Alias>
+struct __aliasing_iterator_wrapper {
+ class __iterator {
+ _BaseT* __base_ = nullptr;
+
+ public:
+ // All five member types are provided: before C++20, `iterator_traits` is empty unless every one of them exists,
+ // which would make `__iter_value_type<__iterator>` unusable in the vectorized algorithms.
+ using iterator_category = random_access_iterator_tag;
+ using value_type = _Alias;
+ using difference_type = ptrdiff_t;
+ // operator* and operator[] return by value, so the "reference" type is the value type itself.
+ using reference = value_type;
+ // There is no operator->, hence no meaningful pointer type.
+ using pointer = void;
+
+ static_assert(is_trivial<_Alias>::value);
+ static_assert(sizeof(_BaseT) == sizeof(_Alias));
+
+ _LIBCPP_HIDE_FROM_ABI __iterator() = default;
+ _LIBCPP_HIDE_FROM_ABI __iterator(_BaseT* __base) _NOEXCEPT : __base_(__base) {}
+
+ _LIBCPP_HIDE_FROM_ABI __iterator& operator++() _NOEXCEPT {
+ ++__base_;
+ return *this;
+ }
+
+ _LIBCPP_HIDE_FROM_ABI __iterator operator++(int) _NOEXCEPT {
+ __iterator __tmp(*this);
+ ++__base_;
+ return __tmp;
+ }
+
+ _LIBCPP_HIDE_FROM_ABI __iterator& operator--() _NOEXCEPT {
+ --__base_;
+ return *this;
+ }
+
+ _LIBCPP_HIDE_FROM_ABI __iterator operator--(int) _NOEXCEPT {
+ __iterator __tmp(*this);
+ --__base_;
+ return __tmp;
+ }
+
+ _LIBCPP_HIDE_FROM_ABI friend __iterator operator+(__iterator __iter, difference_type __n) _NOEXCEPT {
+ return __iterator(__iter.__base_ + __n);
+ }
+
+ _LIBCPP_HIDE_FROM_ABI friend __iterator operator+(difference_type __n, __iterator __iter) _NOEXCEPT {
+ return __iterator(__n + __iter.__base_);
+ }
+
+ _LIBCPP_HIDE_FROM_ABI __iterator& operator+=(difference_type __n) _NOEXCEPT {
+ __base_ += __n;
+ return *this;
+ }
+
+ _LIBCPP_HIDE_FROM_ABI friend __iterator operator-(__iterator __iter, difference_type __n) _NOEXCEPT {
+ return __iterator(__iter.__base_ - __n);
+ }
+
+ _LIBCPP_HIDE_FROM_ABI friend difference_type operator-(__iterator __lhs, __iterator __rhs) _NOEXCEPT {
+ return __lhs.__base_ - __rhs.__base_;
+ }
+
+ _LIBCPP_HIDE_FROM_ABI __iterator& operator-=(difference_type __n) _NOEXCEPT {
+ __base_ -= __n;
+ return *this;
+ }
+
+ // Access to the wrapped pointer, used to translate results back to `_BaseT*`.
+ _LIBCPP_HIDE_FROM_ABI _BaseT* base() _NOEXCEPT { return __base_; }
+ _LIBCPP_HIDE_FROM_ABI const _BaseT* base() const _NOEXCEPT { return __base_; }
+
+ // Reads the current element by copying its bytes into an `_Alias`; the sizes are equal per the static_assert above.
+ _LIBCPP_HIDE_FROM_ABI _Alias operator*() const _NOEXCEPT {
+ _Alias __val;
+ __builtin_memcpy(&__val, __base_, sizeof(_BaseT));
+ return __val;
+ }
+
+ _LIBCPP_HIDE_FROM_ABI _Alias operator[](difference_type __n) const _NOEXCEPT { return *(*this + __n); }
+
+ _LIBCPP_HIDE_FROM_ABI friend bool operator==(const __iterator& __lhs, const __iterator& __rhs) _NOEXCEPT {
+ return __lhs.__base_ == __rhs.__base_;
+ }
+
+ _LIBCPP_HIDE_FROM_ABI friend bool operator!=(const __iterator& __lhs, const __iterator& __rhs) _NOEXCEPT {
+ return __lhs.__base_ != __rhs.__base_;
+ }
+
+#if _LIBCPP_STD_VER >= 20
+ friend auto operator<=>(__iterator, __iterator) noexcept = default;
+#endif
+ };
+};
+
+// This is required to avoid ADL instantiations on _BaseT. `typename` is spelled out because the nested name is
+// dependent; leaving it off is only accepted from C++20 onwards, and this header is also parsed in older modes.
+template <class _BaseT, class _Alias>
+using __aliasing_iterator = typename __aliasing_iterator_wrapper<_BaseT, _Alias>::__iterator;
+
+_LIBCPP_END_NAMESPACE_STD
+
+#endif // _LIBCPP___ITERATOR_ALIASING_ITERATOR_H
diff --git a/libcxx/include/libcxx.imp b/libcxx/include/libcxx.imp
index 607f63e6d82206..e958756004273e 100644
--- a/libcxx/include/libcxx.imp
+++ b/libcxx/include/libcxx.imp
@@ -450,6 +450,7 @@
{ include: [ "<__ios/fpos.h>", "private", "<ios>", "public" ] },
{ include: [ "<__iterator/access.h>", "private", "<iterator>", "public" ] },
{ include: [ "<__iterator/advance.h>", "private", "<iterator>", "public" ] },
+ { include: [ "<__iterator/aliasing_iterator.h>", "private", "<iterator>", "public" ] },
{ include: [ "<__iterator/back_insert_iterator.h>", "private", "<iterator>", "public" ] },
{ include: [ "<__iterator/bounded_iter.h>", "private", "<iterator>", "public" ] },
{ include: [ "<__iterator/common_iterator.h>", "private", "<iterator>", "public" ] },
diff --git a/libcxx/include/module.modulemap b/libcxx/include/module.modulemap
index ed45a1b1833893..d2ce35e3fe481e 100644
--- a/libcxx/include/module.modulemap
+++ b/libcxx/include/module.modulemap
@@ -1406,6 +1406,7 @@ module std_private_iosfwd_streambuf_fwd [system] { header "__fwd/streambuf.h" }
module std_private_iterator_access [system] { header "__iterator/access.h" }
module std_private_iterator_advance [system] { header "__iterator/advance.h" }
+module std_private_iterator_aliasing_iterator [system] { header "__iterator/aliasing_iterator.h" }
module std_private_iterator_back_insert_iterator [system] { header "__iterator/back_insert_iterator.h" }
module std_private_iterator_bounded_iter [system] { header "__iterator/bounded_iter.h" }
module std_private_iterator_common_iterator [system] { header "__iterator/common_iterator.h" }
diff --git a/libcxx/test/std/algorithms/alg.nonmodifying/mismatch/mismatch.pass.cpp b/libcxx/test/std/algorithms/alg.nonmodifying/mismatch/mismatch.pass.cpp
index eb5f7cacdde34b..72df17628dad78 100644
--- a/libcxx/test/std/algorithms/alg.nonmodifying/mismatch/mismatch.pass.cpp
+++ b/libcxx/test/std/algorithms/alg.nonmodifying/mismatch/mismatch.pass.cpp
@@ -66,14 +66,27 @@ TEST_CONSTEXPR_CXX20 void check(Container1 lhs, Container2 rhs, size_t offset) {
#endif
}
-struct NonTrivial {
+// Compares modulo 4 to make sure we only forward to the vectorized version if we are trivially equality comparable
+struct NonTrivialMod4Comp {
int i_;
- TEST_CONSTEXPR_CXX20 NonTrivial(int i) : i_(i) {}
- TEST_CONSTEXPR_CXX20 NonTrivial(NonTrivial&& other) : i_(other.i_) { other.i_ = 0; }
+ TEST_CONSTEXPR_CXX20 NonTrivialMod4Comp(int i) : i_(i) {}
+ TEST_CONSTEXPR_CXX20 NonTrivialMod4Comp(NonTrivialMod4Comp&& other) : i_(other.i_) { other.i_ = 0; }
- TEST_CONSTEXPR_CXX20 friend bool operator==(const NonTrivial& lhs, const NonTrivial& rhs) { return lhs.i_ == rhs.i_; }
+ TEST_CONSTEXPR_CXX20 friend bool operator==(const NonTrivialMod4Comp& lhs, const NonTrivialMod4Comp& rhs) {
+ return lhs.i_ % 4 == rhs.i_ % 4;
+ }
+};
+
+#if TEST_STD_VER >= 20
+struct TriviallyEqualityComparable {
+ int i_;
+
+ TEST_CONSTEXPR_CXX20 TriviallyEqualityComparable(int i) : i_(i) {}
+
+ TEST_CONSTEXPR_CXX20 friend bool operator==(TriviallyEqualityComparable, TriviallyEqualityComparable) = default;
};
+#endif // TEST_STD_VER >= 20
struct ModTwoComp {
TEST_CONSTEXPR_CXX20 bool operator()(int lhs, int rhs) { return lhs % 2 == rhs % 2; }
@@ -136,16 +149,30 @@ TEST_CONSTEXPR_CXX20 bool test() {
types::for_each(types::cpp17_input_iterator_list<int*>(), Test());
{ // use a non-integer type to also test the general case - all elements match
- std::array<NonTrivial, 8> lhs = {1, 2, 3, 4, 5, 6, 7, 8};
- std::array<NonTrivial, 8> rhs = {1, 2, 3, 4, 5, 6, 7, 8};
- check<NonTrivial*>(std::move(lhs), std::move(rhs), 8);
+ std::array<NonTrivialMod4Comp, 8> lhs = {1, 2, 3, 4, 5, 6, 7, 8};
+ std::array<NonTrivialMod4Comp, 8> rhs = {1, 2, 3, 4, 1, 6, 7, 8};
+ check<NonTrivialMod4Comp*>(std::move(lhs), std::move(rhs), 8);
}
{ // use a non-integer type to also test the general case - not all elements match
- std::array<NonTrivial, 8> lhs = {1, 2, 3, 4, 7, 6, 7, 8};
- std::array<NonTrivial, 8> rhs = {1, 2, 3, 4, 5, 6, 7, 8};
- check<NonTrivial*>(std::move(lhs), std::move(rhs), 4);
+ std::array<NonTrivialMod4Comp, 8> lhs = {1, 2, 3, 4, 7, 6, 7, 8};
+ std::array<NonTrivialMod4Comp, 8> rhs = {1, 2, 3, 4, 5, 6, 7, 8};
+ check<NonTrivialMod4Comp*>(std::move(lhs), std::move(rhs), 4);
+ }
+
+#if TEST_STD_VER >= 20
+ { // trivially equality comparable class type to test forwarding to the vectorized version - all elements match
+ std::array<TriviallyEqualityComparable, 8> lhs = {1, 2, 3, 4, 5, 6, 7, 8};
+ std::array<TriviallyEqualityComparable, 8> rhs = {1, 2, 3, 4, 5, 6, 7, 8};
+ check<TriviallyEqualityComparable*>(std::move(lhs), std::move(rhs), 8);
+ }
+
+ { // trivially equality comparable class type to test forwarding to the vectorized version - not all elements match
+ std::array<TriviallyEqualityComparable, 8> lhs = {1, 2, 3, 4, 7, 6, 7, 8};
+ std::array<TriviallyEqualityComparable, 8> rhs = {1, 2, 3, 4, 5, 6, 7, 8};
+ check<TriviallyEqualityComparable*>(std::move(lhs), std::move(rhs), 4);
}
+#endif // TEST_STD_VER >= 20
return true;
}
|
6034ac5
to
8bf6114
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM w/ nitpicks and the CI passing.
dc3aa12
to
5a7d78b
Compare
5a7d78b
to
2ce6413
Compare
No description provided.