
[libc++] Vectorize std::mismatch with trivially equality comparable types #87716

philnik777
Contributor

No description provided.

github-actions bot commented Apr 4, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

@philnik777 philnik777 force-pushed the mismatch_vectorize_trivially_equality_comparable branch from 57dfddd to dc09211 Compare April 5, 2024 09:46
@philnik777 philnik777 changed the title [libc++] Vectorize trivially equality comparable types [libc++] Vectorize std::mismatch with trivially equality comparable types Apr 5, 2024
@philnik777 philnik777 force-pushed the mismatch_vectorize_trivially_equality_comparable branch 6 times, most recently from 7c667e6 to 5346884 Compare April 12, 2024 10:34
Member

@ldionne ldionne left a comment

This LGTM on a high level, but I have some comments and I'd like to see this again before it ships, since it's tricky. But the approach looks fine to me. This is clever!
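For readers skimming the thread, the approach under review can be approximated outside libc++ roughly as follows. This is only a minimal sketch: the names `as_integer_t`, `load_as_integer`, `mismatch_as_integers`, and `Id` are hypothetical and do not appear in the PR, and the plain scalar loop stands in for the real vectorized loop. The idea it illustrates is that an object whose equality is purely bytewise can be read through `memcpy` as an unsigned integer of the same size and then compared as that integer.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <type_traits>
#include <utility>

// Hypothetical helper: map an object size to an unsigned integer of that size.
template <std::size_t Size> struct as_integer;
template <> struct as_integer<1> { using type = std::uint8_t; };
template <> struct as_integer<2> { using type = std::uint16_t; };
template <> struct as_integer<4> { using type = std::uint32_t; };
template <> struct as_integer<8> { using type = std::uint64_t; };

template <class T>
using as_integer_t = typename as_integer<sizeof(T)>::type;

// Read the object representation of *p as an integer. Copying with memcpy
// sidesteps the strict-aliasing problem a reinterpret_cast would introduce.
template <class T>
as_integer_t<T> load_as_integer(const T* p) noexcept {
  static_assert(std::is_trivially_copyable<T>::value, "sketch assumes a trivially copyable T");
  as_integer_t<T> value;
  std::memcpy(&value, p, sizeof(T));
  return value;
}

// Scalar stand-in for the vectorized loop: compare elements by their integer image.
// Assumption: equality of T is exactly equality of its bytes (no padding, no custom ==).
template <class T>
std::pair<const T*, const T*> mismatch_as_integers(const T* first1, const T* last1, const T* first2) noexcept {
  while (first1 != last1 && load_as_integer(first1) == load_as_integer(first2)) {
    ++first1;
    ++first2;
  }
  return {first1, first2};
}

// Hypothetical element type: a single int, so there are no padding bytes.
struct Id {
  int value;
};

// Returns the index of the first differing element in two small sample arrays.
inline std::ptrdiff_t demo_first_difference() {
  Id lhs[] = {{1}, {2}, {3}, {4}};
  Id rhs[] = {{1}, {2}, {9}, {4}};
  auto result = mismatch_as_integers(lhs, lhs + 4, rhs);
  assert(result.first - lhs == result.second - rhs);
  return result.first - lhs;
}
```

The sketch deliberately skips the alignment condition (`sizeof(_Tp) == alignof(_Tp)`) that the PR's `__can_map_to_integer_v` checks; a byte copy does not need it, but a vector load over the underlying storage may.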

@ldionne ldionne marked this pull request as ready for review April 12, 2024 15:47
@ldionne ldionne requested a review from a team as a code owner April 12, 2024 15:47
@llvmbot llvmbot added the libc++ libc++ C++ Standard Library. Not GNU libstdc++. Not libc++abi. label Apr 12, 2024
@llvmbot
Member

llvmbot commented Apr 12, 2024

@llvm/pr-subscribers-libcxx

Author: Nikolas Klauser (philnik777)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/87716.diff

7 Files Affected:

  • (modified) libcxx/include/CMakeLists.txt (+1)
  • (modified) libcxx/include/__algorithm/mismatch.h (+40-12)
  • (modified) libcxx/include/__algorithm/simd_utils.h (+31-3)
  • (added) libcxx/include/__iterator/aliasing_iterator.h (+120)
  • (modified) libcxx/include/libcxx.imp (+1)
  • (modified) libcxx/include/module.modulemap (+1)
  • (modified) libcxx/test/std/algorithms/alg.nonmodifying/mismatch/mismatch.pass.cpp (+37-10)
diff --git a/libcxx/include/CMakeLists.txt b/libcxx/include/CMakeLists.txt
index 097a41d4c41740..5922c9106853ed 100644
--- a/libcxx/include/CMakeLists.txt
+++ b/libcxx/include/CMakeLists.txt
@@ -456,6 +456,7 @@ set(files
   __ios/fpos.h
   __iterator/access.h
   __iterator/advance.h
+  __iterator/aliasing_iterator.h
   __iterator/back_insert_iterator.h
   __iterator/bounded_iter.h
   __iterator/common_iterator.h
diff --git a/libcxx/include/__algorithm/mismatch.h b/libcxx/include/__algorithm/mismatch.h
index 4ada29eabc470c..c25fae67c7cfb3 100644
--- a/libcxx/include/__algorithm/mismatch.h
+++ b/libcxx/include/__algorithm/mismatch.h
@@ -16,6 +16,7 @@
 #include <__algorithm/unwrap_iter.h>
 #include <__config>
 #include <__functional/identity.h>
+#include <__iterator/aliasing_iterator.h>
 #include <__type_traits/desugars_to.h>
 #include <__type_traits/invoke.h>
 #include <__type_traits/is_constant_evaluated.h>
@@ -55,18 +56,13 @@ __mismatch(_Iter1 __first1, _Sent1 __last1, _Iter2 __first2, _Pred& __pred, _Pro
 
 #if _LIBCPP_VECTORIZE_ALGORITHMS
 
-template <class _Tp,
-          class _Pred,
-          class _Proj1,
-          class _Proj2,
-          __enable_if_t<is_integral<_Tp>::value && __desugars_to_v<__equal_tag, _Pred, _Tp, _Tp> &&
-                            __is_identity<_Proj1>::value && __is_identity<_Proj2>::value,
-                        int> = 0>
-_LIBCPP_NODISCARD _LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX20 pair<_Tp*, _Tp*>
-__mismatch(_Tp* __first1, _Tp* __last1, _Tp* __first2, _Pred& __pred, _Proj1& __proj1, _Proj2& __proj2) {
+template <class _Iter>
+_LIBCPP_NODISCARD _LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX20 pair<_Iter, _Iter>
+__mismatch_vectorized(_Iter __first1, _Iter __last1, _Iter __first2) {
+  using __value_type              = __iter_value_type<_Iter>;
   constexpr size_t __unroll_count = 4;
-  constexpr size_t __vec_size     = __native_vector_size<_Tp>;
-  using __vec                     = __simd_vector<_Tp, __vec_size>;
+  constexpr size_t __vec_size     = __native_vector_size<__value_type>;
+  using __vec                     = __simd_vector<__value_type, __vec_size>;
 
   if (!__libcpp_is_constant_evaluated()) {
     auto __orig_first1 = __first1;
@@ -116,9 +112,41 @@ __mismatch(_Tp* __first1, _Tp* __last1, _Tp* __first2, _Pred& __pred, _Proj1& __
     } // else loop over the elements individually
   }
 
-  return std::__mismatch_loop(__first1, __last1, __first2, __pred, __proj1, __proj2);
+  __equal_to __pred;
+  __identity __proj;
+  return std::__mismatch_loop(__first1, __last1, __first2, __pred, __proj, __proj);
+}
+
+template <class _Tp,
+          class _Pred,
+          class _Proj1,
+          class _Proj2,
+          __enable_if_t<is_integral<_Tp>::value && __desugars_to_v<__equal_tag, _Pred, _Tp, _Tp> &&
+                            __is_identity<_Proj1>::value && __is_identity<_Proj2>::value,
+                        int> = 0>
+_LIBCPP_NODISCARD _LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX20 pair<_Tp*, _Tp*>
+__mismatch(_Tp* __first1, _Tp* __last1, _Tp* __first2, _Pred&, _Proj1&, _Proj2&) {
+  return std::__mismatch_vectorized(__first1, __last1, __first2);
 }
 
+template <class _Tp,
+          class _Pred,
+          class _Proj1,
+          class _Proj2,
+          __enable_if_t<!is_integral<_Tp>::value && __desugars_to_v<__equal_tag, _Pred, _Tp, _Tp> &&
+                            __is_identity<_Proj1>::value && __is_identity<_Proj2>::value &&
+                            __can_map_to_integer_v<_Tp> && __libcpp_is_trivially_equality_comparable<_Tp, _Tp>::value,
+                        int> = 0>
+_LIBCPP_NODISCARD _LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX20 pair<_Tp*, _Tp*>
+__mismatch(_Tp* __first1, _Tp* __last1, _Tp* __first2, _Pred& __pred, _Proj1& __proj1, _Proj2& __proj2) {
+  if (__libcpp_is_constant_evaluated()) {
+    return std::__mismatch_loop(__first1, __last1, __first2, __pred, __proj1, __proj2);
+  } else {
+    using _Iter = __aliasing_iterator<_Tp, __get_as_integer_type_t<_Tp>>;
+    auto __ret  = std::__mismatch_vectorized(_Iter(__first1), _Iter(__last1), _Iter(__first2));
+    return {__ret.first.base(), __ret.second.base()};
+  }
+}
 #endif // _LIBCPP_VECTORIZE_ALGORITHMS
 
 template <class _InputIterator1, class _InputIterator2, class _BinaryPredicate>
diff --git a/libcxx/include/__algorithm/simd_utils.h b/libcxx/include/__algorithm/simd_utils.h
index 989a1957987e1e..7fd7b3c33ee6de 100644
--- a/libcxx/include/__algorithm/simd_utils.h
+++ b/libcxx/include/__algorithm/simd_utils.h
@@ -43,6 +43,34 @@ _LIBCPP_PUSH_MACROS
 
 _LIBCPP_BEGIN_NAMESPACE_STD
 
+template <class _Tp>
+inline constexpr bool __can_map_to_integer_v =
+    sizeof(_Tp) == alignof(_Tp) && (sizeof(_Tp) == 1 || sizeof(_Tp) == 2 || sizeof(_Tp) == 4 || sizeof(_Tp) == 8);
+
+template <size_t _TypeSize>
+struct __get_as_integer_type_impl;
+
+template <>
+struct __get_as_integer_type_impl<1> {
+  using type = uint8_t;
+};
+
+template <>
+struct __get_as_integer_type_impl<2> {
+  using type = uint16_t;
+};
+template <>
+struct __get_as_integer_type_impl<4> {
+  using type = uint32_t;
+};
+template <>
+struct __get_as_integer_type_impl<8> {
+  using type = uint64_t;
+};
+
+template <class _Tp>
+using __get_as_integer_type_t = typename __get_as_integer_type_impl<sizeof(_Tp)>::type;
+
 // This isn't specialized for 64 byte vectors on purpose. They have the potential to significantly reduce performance
 // in mixed simd/non-simd workloads and don't provide any performance improvement for currently vectorized algorithms
 // as far as benchmarks are concerned.
@@ -80,10 +108,10 @@ template <class _VecT>
 using __simd_vector_underlying_type_t = decltype(std::__simd_vector_underlying_type_impl(_VecT{}));
 
 // This isn't inlined without always_inline when loading chars.
-template <class _VecT, class _Tp>
-_LIBCPP_NODISCARD _LIBCPP_ALWAYS_INLINE _LIBCPP_HIDE_FROM_ABI _VecT __load_vector(const _Tp* __ptr) noexcept {
+template <class _VecT, class _Iter>
+_LIBCPP_NODISCARD _LIBCPP_ALWAYS_INLINE _LIBCPP_HIDE_FROM_ABI _VecT __load_vector(_Iter __iter) noexcept {
   return [=]<size_t... _Indices>(index_sequence<_Indices...>) _LIBCPP_ALWAYS_INLINE noexcept {
-    return _VecT{__ptr[_Indices]...};
+    return _VecT{__iter[_Indices]...};
   }(make_index_sequence<__simd_vector_size_v<_VecT>>{});
 }
 
diff --git a/libcxx/include/__iterator/aliasing_iterator.h b/libcxx/include/__iterator/aliasing_iterator.h
new file mode 100644
index 00000000000000..eb2b59af188f7b
--- /dev/null
+++ b/libcxx/include/__iterator/aliasing_iterator.h
@@ -0,0 +1,120 @@
+//===----------------------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef _LIBCPP___ITERATOR_ALIASING_ITERATOR_H
+#define _LIBCPP___ITERATOR_ALIASING_ITERATOR_H
+
+#include <__compare/strong_order.h>
+#include <__config>
+#include <__iterator/concepts.h>
+#include <__iterator/iterator_traits.h>
+#include <__type_traits/is_trivial.h>
+#include <cstddef>
+
+#if !defined(_LIBCPP_HAS_NO_PRAGMA_SYSTEM_HEADER)
+#  pragma GCC system_header
+#endif
+
+_LIBCPP_BEGIN_NAMESPACE_STD
+
+template <class _BaseT, class _Alias>
+struct __aliasing_iterator_wrapper {
+  class __iterator {
+    _BaseT* __base_ = nullptr;
+
+  public:
+    using iterator_category = random_access_iterator_tag;
+    using value_type        = _Alias;
+    using difference_type   = ptrdiff_t;
+
+    static_assert(is_trivial<_Alias>::value);
+    static_assert(sizeof(_BaseT) == sizeof(_Alias));
+
+    _LIBCPP_HIDE_FROM_ABI __iterator() = default;
+    _LIBCPP_HIDE_FROM_ABI __iterator(_BaseT* __base) _NOEXCEPT : __base_(__base) {}
+
+    _LIBCPP_HIDE_FROM_ABI __iterator& operator++() _NOEXCEPT {
+      ++__base_;
+      return *this;
+    }
+
+    _LIBCPP_HIDE_FROM_ABI __iterator operator++(int) _NOEXCEPT {
+      __iterator __tmp(*this);
+      ++__base_;
+      return __tmp;
+    }
+
+    _LIBCPP_HIDE_FROM_ABI __iterator& operator--() _NOEXCEPT {
+      --__base_;
+      return *this;
+    }
+
+    _LIBCPP_HIDE_FROM_ABI __iterator operator--(int) _NOEXCEPT {
+      __iterator __tmp(*this);
+      --__base_;
+      return __tmp;
+    }
+
+    _LIBCPP_HIDE_FROM_ABI friend __iterator operator+(__iterator __iter, difference_type __n) _NOEXCEPT {
+      return __iterator(__iter.__base_ + __n);
+    }
+
+    _LIBCPP_HIDE_FROM_ABI friend __iterator operator+(difference_type __n, __iterator __iter) _NOEXCEPT {
+      return __iterator(__n + __iter.__base_);
+    }
+
+    _LIBCPP_HIDE_FROM_ABI __iterator& operator+=(difference_type __n) _NOEXCEPT {
+      __base_ += __n;
+      return *this;
+    }
+
+    _LIBCPP_HIDE_FROM_ABI friend __iterator operator-(__iterator __iter, difference_type __n) _NOEXCEPT {
+      return __iterator(__iter.__base_ - __n);
+    }
+
+    _LIBCPP_HIDE_FROM_ABI friend difference_type operator-(__iterator __lhs, __iterator __rhs) _NOEXCEPT {
+      return __lhs.__base_ - __rhs.__base_;
+    }
+
+    _LIBCPP_HIDE_FROM_ABI __iterator& operator-=(difference_type __n) _NOEXCEPT {
+      __base_ -= __n;
+      return *this;
+    }
+
+    _LIBCPP_HIDE_FROM_ABI _BaseT* base() _NOEXCEPT { return __base_; }
+    _LIBCPP_HIDE_FROM_ABI const _BaseT* base() const _NOEXCEPT { return __base_; }
+
+    _LIBCPP_HIDE_FROM_ABI _Alias operator*() const _NOEXCEPT {
+      _Alias __val;
+      __builtin_memcpy(&__val, __base_, sizeof(_BaseT));
+      return __val;
+    }
+
+    _LIBCPP_HIDE_FROM_ABI _Alias operator[](difference_type __n) const _NOEXCEPT { return *(*this + __n); }
+
+    _LIBCPP_HIDE_FROM_ABI friend bool operator==(const __iterator& __lhs, const __iterator& __rhs) _NOEXCEPT {
+      return __lhs.__base_ == __rhs.__base_;
+    }
+
+    _LIBCPP_HIDE_FROM_ABI friend bool operator!=(const __iterator& __lhs, const __iterator& __rhs) _NOEXCEPT {
+      return __lhs.__base_ != __rhs.__base_;
+    }
+
+#if _LIBCPP_STD_VER >= 20
+    friend auto operator<=>(__iterator, __iterator) noexcept = default;
+#endif
+  };
+};
+
+// This is required to avoid ADL instantiations on _BaseT
+template <class _BaseT, class _Alias>
+using __aliasing_iterator = typename __aliasing_iterator_wrapper<_BaseT, _Alias>::__iterator;
+
+_LIBCPP_END_NAMESPACE_STD
+
+#endif // _LIBCPP___ITERATOR_ALIASING_ITERATOR_H
diff --git a/libcxx/include/libcxx.imp b/libcxx/include/libcxx.imp
index 607f63e6d82206..e958756004273e 100644
--- a/libcxx/include/libcxx.imp
+++ b/libcxx/include/libcxx.imp
@@ -450,6 +450,7 @@
   { include: [ "<__ios/fpos.h>", "private", "<ios>", "public" ] },
   { include: [ "<__iterator/access.h>", "private", "<iterator>", "public" ] },
   { include: [ "<__iterator/advance.h>", "private", "<iterator>", "public" ] },
+  { include: [ "<__iterator/aliasing_iterator.h>", "private", "<iterator>", "public" ] },
   { include: [ "<__iterator/back_insert_iterator.h>", "private", "<iterator>", "public" ] },
   { include: [ "<__iterator/bounded_iter.h>", "private", "<iterator>", "public" ] },
   { include: [ "<__iterator/common_iterator.h>", "private", "<iterator>", "public" ] },
diff --git a/libcxx/include/module.modulemap b/libcxx/include/module.modulemap
index ed45a1b1833893..d2ce35e3fe481e 100644
--- a/libcxx/include/module.modulemap
+++ b/libcxx/include/module.modulemap
@@ -1406,6 +1406,7 @@ module std_private_iosfwd_streambuf_fwd [system] { header "__fwd/streambuf.h" }
 
 module std_private_iterator_access                  [system] { header "__iterator/access.h" }
 module std_private_iterator_advance                 [system] { header "__iterator/advance.h" }
+module std_private_iterator_aliasing_iterator       [system] { header "__iterator/aliasing_iterator.h" }
 module std_private_iterator_back_insert_iterator    [system] { header "__iterator/back_insert_iterator.h" }
 module std_private_iterator_bounded_iter            [system] { header "__iterator/bounded_iter.h" }
 module std_private_iterator_common_iterator         [system] { header "__iterator/common_iterator.h" }
diff --git a/libcxx/test/std/algorithms/alg.nonmodifying/mismatch/mismatch.pass.cpp b/libcxx/test/std/algorithms/alg.nonmodifying/mismatch/mismatch.pass.cpp
index eb5f7cacdde34b..72df17628dad78 100644
--- a/libcxx/test/std/algorithms/alg.nonmodifying/mismatch/mismatch.pass.cpp
+++ b/libcxx/test/std/algorithms/alg.nonmodifying/mismatch/mismatch.pass.cpp
@@ -66,14 +66,27 @@ TEST_CONSTEXPR_CXX20 void check(Container1 lhs, Container2 rhs, size_t offset) {
 #endif
 }
 
-struct NonTrivial {
+// Compares modulo 4 to make sure we only forward to the vectorized version if we are trivially equality comparable
+struct NonTrivialMod4Comp {
   int i_;
 
-  TEST_CONSTEXPR_CXX20 NonTrivial(int i) : i_(i) {}
-  TEST_CONSTEXPR_CXX20 NonTrivial(NonTrivial&& other) : i_(other.i_) { other.i_ = 0; }
+  TEST_CONSTEXPR_CXX20 NonTrivialMod4Comp(int i) : i_(i) {}
+  TEST_CONSTEXPR_CXX20 NonTrivialMod4Comp(NonTrivialMod4Comp&& other) : i_(other.i_) { other.i_ = 0; }
 
-  TEST_CONSTEXPR_CXX20 friend bool operator==(const NonTrivial& lhs, const NonTrivial& rhs) { return lhs.i_ == rhs.i_; }
+  TEST_CONSTEXPR_CXX20 friend bool operator==(const NonTrivialMod4Comp& lhs, const NonTrivialMod4Comp& rhs) {
+    return lhs.i_ % 4 == rhs.i_ % 4;
+  }
+};
+
+#if TEST_STD_VER >= 20
+struct TriviallyEqualityComparable {
+  int i_;
+
+  TEST_CONSTEXPR_CXX20 TriviallyEqualityComparable(int i) : i_(i) {}
+
+  TEST_CONSTEXPR_CXX20 friend bool operator==(TriviallyEqualityComparable, TriviallyEqualityComparable) = default;
 };
+#endif // TEST_STD_VER >= 20
 
 struct ModTwoComp {
   TEST_CONSTEXPR_CXX20 bool operator()(int lhs, int rhs) { return lhs % 2 == rhs % 2; }
@@ -136,16 +149,30 @@ TEST_CONSTEXPR_CXX20 bool test() {
   types::for_each(types::cpp17_input_iterator_list<int*>(), Test());
 
   { // use a non-integer type to also test the general case - all elements match
-    std::array<NonTrivial, 8> lhs = {1, 2, 3, 4, 5, 6, 7, 8};
-    std::array<NonTrivial, 8> rhs = {1, 2, 3, 4, 5, 6, 7, 8};
-    check<NonTrivial*>(std::move(lhs), std::move(rhs), 8);
+    std::array<NonTrivialMod4Comp, 8> lhs = {1, 2, 3, 4, 5, 6, 7, 8};
+    std::array<NonTrivialMod4Comp, 8> rhs = {1, 2, 3, 4, 1, 6, 7, 8};
+    check<NonTrivialMod4Comp*>(std::move(lhs), std::move(rhs), 8);
   }
 
   { // use a non-integer type to also test the general case - not all elements match
-    std::array<NonTrivial, 8> lhs = {1, 2, 3, 4, 7, 6, 7, 8};
-    std::array<NonTrivial, 8> rhs = {1, 2, 3, 4, 5, 6, 7, 8};
-    check<NonTrivial*>(std::move(lhs), std::move(rhs), 4);
+    std::array<NonTrivialMod4Comp, 8> lhs = {1, 2, 3, 4, 7, 6, 7, 8};
+    std::array<NonTrivialMod4Comp, 8> rhs = {1, 2, 3, 4, 5, 6, 7, 8};
+    check<NonTrivialMod4Comp*>(std::move(lhs), std::move(rhs), 4);
+  }
+
+#if TEST_STD_VER >= 20
+  { // trivially equality comparable class type to test forwarding to the vectorized version - all elements match
+    std::array<TriviallyEqualityComparable, 8> lhs = {1, 2, 3, 4, 5, 6, 7, 8};
+    std::array<TriviallyEqualityComparable, 8> rhs = {1, 2, 3, 4, 5, 6, 7, 8};
+    check<TriviallyEqualityComparable*>(std::move(lhs), std::move(rhs), 8);
+  }
+
+  { // trivially equality comparable class type to test forwarding to the vectorized version - not all elements match
+    std::array<TriviallyEqualityComparable, 8> lhs = {1, 2, 3, 4, 7, 6, 7, 8};
+    std::array<TriviallyEqualityComparable, 8> rhs = {1, 2, 3, 4, 5, 6, 7, 8};
+    check<TriviallyEqualityComparable*>(std::move(lhs), std::move(rhs), 4);
   }
+#endif // TEST_STD_VER >= 20
 
   return true;
 }
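
The test change above renames `NonTrivial` to `NonTrivialMod4Comp` and makes its `operator==` compare modulo 4, so a wrongly vectorized (bytewise) comparison would give a different answer from the element-wise one. The standalone sketch below illustrates that reasoning with the same data as the "all elements match" case; `Mod4`, `sample_lhs`, `sample_rhs`, and the two `first_mismatch_*` functions are hypothetical names chosen for illustration, not part of the libc++ test suite.

```cpp
#include <algorithm>
#include <array>
#include <cstddef>
#include <cstring>

// Hypothetical type whose equality is NOT bytewise: 1 and 5 compare equal.
struct Mod4 {
  int value;
  friend bool operator==(const Mod4& a, const Mod4& b) { return a.value % 4 == b.value % 4; }
};

inline std::array<Mod4, 8> sample_lhs() {
  return {Mod4{1}, Mod4{2}, Mod4{3}, Mod4{4}, Mod4{5}, Mod4{6}, Mod4{7}, Mod4{8}};
}

// Differs from sample_lhs() only at index 4 (1 instead of 5), which is equal modulo 4.
inline std::array<Mod4, 8> sample_rhs() {
  return {Mod4{1}, Mod4{2}, Mod4{3}, Mod4{4}, Mod4{1}, Mod4{6}, Mod4{7}, Mod4{8}};
}

// What std::mismatch must report: the first index where operator== says "different".
inline std::size_t first_mismatch_by_equality(const std::array<Mod4, 8>& a, const std::array<Mod4, 8>& b) {
  auto result = std::mismatch(a.begin(), a.end(), b.begin());
  return static_cast<std::size_t>(result.first - a.begin());
}

// What a bytewise comparison would report if it were (incorrectly) applied to Mod4.
inline std::size_t first_mismatch_by_bytes(const std::array<Mod4, 8>& a, const std::array<Mod4, 8>& b) {
  std::size_t i = 0;
  while (i < a.size() && std::memcmp(&a[i], &b[i], sizeof(Mod4)) == 0)
    ++i;
  return i;
}
```

With this data the two functions disagree (no mismatch under `operator==`, a mismatch at index 4 under the byte comparison), which is the discrepancy the updated test relies on to detect an incorrect dispatch to the vectorized path.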

@philnik777 philnik777 force-pushed the mismatch_vectorize_trivially_equality_comparable branch 2 times, most recently from 6034ac5 to 8bf6114 Compare April 14, 2024 07:26
@ldionne ldionne self-assigned this Apr 15, 2024
Member

@ldionne ldionne left a comment

LGTM w/ nitpicks and the CI passing.

@philnik777 philnik777 force-pushed the mismatch_vectorize_trivially_equality_comparable branch 2 times, most recently from dc3aa12 to 5a7d78b Compare May 10, 2024 09:47
@philnik777 philnik777 force-pushed the mismatch_vectorize_trivially_equality_comparable branch from 5a7d78b to 2ce6413 Compare May 11, 2024 18:28
@philnik777 philnik777 merged commit 05cc2d5 into llvm:main May 11, 2024
51 checks passed
@philnik777 philnik777 deleted the mismatch_vectorize_trivially_equality_comparable branch May 11, 2024 21:32
Labels
libc++ libc++ C++ Standard Library. Not GNU libstdc++. Not libc++abi.

3 participants