Skip to content

[libc++] Optimize ranges::equal for vector<bool>::iterator #121084

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Feb 26, 2025

Conversation

winner245
Copy link
Contributor

@winner245 winner245 commented Dec 25, 2024

This PR optimizes the performance of std::ranges::equal for vector<bool>::iterator, addressing a subtask outlined in issue #64038. The optimizations yield performance improvements of up to 188x for aligned equality comparison and 82x for unaligned equality comparison. Moreover, comprehensive tests covering up to 4 storage words (256 bytes) with odd and even bit sizes are provided, which validate the proposed optimizations in this patch.

  • Aligned equality comparison
-----------------------------------------------------------------------------
Benchmark                                   Before        After   Improvement
-----------------------------------------------------------------------------
bm_ranges_equal_vb_aligned/8               11.9 ns      0.973 ns          12x
bm_ranges_equal_vb_aligned/16              23.5 ns      0.985 ns          24x
bm_ranges_equal_vb_aligned/64              98.3 ns       1.16 ns          85x
bm_ranges_equal_vb_aligned/256              384 ns       2.44 ns         157x
bm_ranges_equal_vb_aligned/1024            1540 ns       8.19 ns         188x
bm_ranges_equal_vb_aligned/4096            6166 ns       38.8 ns         159x
bm_ranges_equal_vb_aligned/16384          24942 ns        142 ns         176x
bm_ranges_equal_vb_aligned/65536          99129 ns        549 ns         181x
bm_ranges_equal_vb_aligned/262144        400353 ns       2180 ns         184x
bm_ranges_equal_vb_aligned/1048576      1577081 ns       8811 ns         179x
  • Unaligned equality comparison
-----------------------------------------------------------------------------
Benchmark                                  Before         After   Improvement
-----------------------------------------------------------------------------
bm_ranges_equal_vb_unaligned/8             11.6 ns       6.01 ns         1.9x
bm_ranges_equal_vb_unaligned/64            92.0 ns       5.81 ns          16x
bm_ranges_equal_vb_unaligned/512            735 ns       15.1 ns          49x
bm_ranges_equal_vb_unaligned/4096          5910 ns         87 ns          68x
bm_ranges_equal_vb_unaligned/32768        47209 ns        618 ns          76x
bm_ranges_equal_vb_unaligned/262144      380280 ns       4767 ns          80x
bm_ranges_equal_vb_unaligned/1048576    1572387 ns      19179 ns          82x

@winner245 winner245 force-pushed the optimize-ranges-equal branch 3 times, most recently from 8dbf8de to 40b8a42 Compare December 25, 2024 17:12
Copy link

github-actions bot commented Dec 25, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

@winner245 winner245 force-pushed the optimize-ranges-equal branch from 40b8a42 to 098b68d Compare December 25, 2024 17:17
@winner245 winner245 force-pushed the optimize-ranges-equal branch 2 times, most recently from d51edbb to 878808f Compare January 21, 2025 21:38
@winner245 winner245 marked this pull request as ready for review January 22, 2025 13:24
@winner245 winner245 requested a review from a team as a code owner January 22, 2025 13:24
@llvmbot llvmbot added the libc++ libc++ C++ Standard Library. Not GNU libstdc++. Not libc++abi. label Jan 22, 2025
@llvmbot
Copy link
Member

llvmbot commented Jan 22, 2025

@llvm/pr-subscribers-libcxx

Author: Peng Liu (winner245)

Changes

This PR optimizes the performance of std::ranges::equal for vector&lt;bool&gt;::iterator, addressing a subtask outlined in issue #64038. The optimizations yield performance improvements of up to 190x for aligned equality comparison and 80x for unaligned equality comparison.

  • Aligned equality comparison
-----------------------------------------------------------------------------
Benchmark                                  Before         After   Improvement
-----------------------------------------------------------------------------
bm_ranges_equal_vb_aligned/8               13.6 ns      0.889 ns          15x
bm_ranges_equal_vb_aligned/64              94.7 ns       1.09 ns          87x
bm_ranges_equal_vb_aligned/512              694 ns       4.15 ns         167x
bm_ranges_equal_vb_aligned/4096            5529 ns       37.4 ns         148x
bm_ranges_equal_vb_aligned/32768          44256 ns        255 ns         173x
bm_ranges_equal_vb_aligned/180224        312311 ns       1695 ns         184x
bm_ranges_equal_vb_aligned/184320        320931 ns       1743 ns         184x
bm_ranges_equal_vb_aligned/188416        325096 ns       1780 ns         183x
bm_ranges_equal_vb_aligned/192512        328834 ns       1806 ns         182x
bm_ranges_equal_vb_aligned/196608        337802 ns       1816 ns         186x
bm_ranges_equal_vb_aligned/200704        360741 ns       1866 ns         193x
bm_ranges_equal_vb_aligned/204800        353518 ns       1915 ns         184x
bm_ranges_equal_vb_aligned/262144        362013 ns       2055 ns         176x
bm_ranges_equal_vb_aligned/1048576      1422572 ns       8406 ns         169x
  • Unaligned equality comparison
-----------------------------------------------------------------------------
Benchmark                                  Before         After   Improvement
-----------------------------------------------------------------------------
bm_ranges_equal_vb_unaligned/8             13.2 ns       5.82 ns         2.3x
bm_ranges_equal_vb_unaligned/64            95.5 ns       5.76 ns          17x
bm_ranges_equal_vb_unaligned/512            717 ns       14.1 ns          51x
bm_ranges_equal_vb_unaligned/4096          5605 ns       80.4 ns          70x
bm_ranges_equal_vb_unaligned/32768        44925 ns        583 ns          77x
bm_ranges_equal_vb_unaligned/262144      360244 ns       4454 ns          81x
bm_ranges_equal_vb_unaligned/1048576    1449077 ns      17869 ns          81x

Patch is 36.80 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/121084.diff

6 Files Affected:

  • (modified) libcxx/include/__algorithm/equal.h (+159)
  • (modified) libcxx/include/__bit_reference (+11-124)
  • (modified) libcxx/include/bitset (+1)
  • (modified) libcxx/test/benchmarks/algorithms/equal.bench.cpp (+51)
  • (modified) libcxx/test/std/algorithms/alg.nonmodifying/alg.equal/equal.pass.cpp (+33)
  • (modified) libcxx/test/std/algorithms/alg.nonmodifying/alg.equal/ranges.equal.pass.cpp (+156-93)
diff --git a/libcxx/include/__algorithm/equal.h b/libcxx/include/__algorithm/equal.h
index a276bb9954c9bb..0f8f1147b193c3 100644
--- a/libcxx/include/__algorithm/equal.h
+++ b/libcxx/include/__algorithm/equal.h
@@ -11,19 +11,27 @@
 #define _LIBCPP___ALGORITHM_EQUAL_H
 
 #include <__algorithm/comp.h>
+#include <__algorithm/min.h>
 #include <__algorithm/unwrap_iter.h>
 #include <__config>
 #include <__functional/identity.h>
+#include <__fwd/bit_reference.h>
 #include <__iterator/distance.h>
 #include <__iterator/iterator_traits.h>
+#include <__memory/pointer_traits.h>
 #include <__string/constexpr_c_functions.h>
 #include <__type_traits/desugars_to.h>
 #include <__type_traits/enable_if.h>
 #include <__type_traits/invoke.h>
 #include <__type_traits/is_equality_comparable.h>
+#include <__type_traits/is_same.h>
 #include <__type_traits/is_volatile.h>
 #include <__utility/move.h>
 
+#if _LIBCPP_STD_VER >= 20
+#  include <__functional/ranges_operations.h>
+#endif
+
 #if !defined(_LIBCPP_HAS_NO_PRAGMA_SYSTEM_HEADER)
 #  pragma GCC system_header
 #endif
@@ -33,6 +41,132 @@ _LIBCPP_PUSH_MACROS
 
 _LIBCPP_BEGIN_NAMESPACE_STD
 
+template <class _Cp, bool _IC1, bool _IC2>
+[[__nodiscard__]] _LIBCPP_CONSTEXPR_SINCE_CXX20 _LIBCPP_HIDE_FROM_ABI bool __equal_unaligned(
+    __bit_iterator<_Cp, _IC1> __first1, __bit_iterator<_Cp, _IC1> __last1, __bit_iterator<_Cp, _IC2> __first2) {
+  using _It             = __bit_iterator<_Cp, _IC1>;
+  using difference_type = typename _It::difference_type;
+  using __storage_type  = typename _It::__storage_type;
+
+  const int __bits_per_word = _It::__bits_per_word;
+  difference_type __n       = __last1 - __first1;
+  if (__n > 0) {
+    // do first word
+    if (__first1.__ctz_ != 0) {
+      unsigned __clz_f     = __bits_per_word - __first1.__ctz_;
+      difference_type __dn = std::min(static_cast<difference_type>(__clz_f), __n);
+      __n -= __dn;
+      __storage_type __m   = (~__storage_type(0) << __first1.__ctz_) & (~__storage_type(0) >> (__clz_f - __dn));
+      __storage_type __b   = *__first1.__seg_ & __m;
+      unsigned __clz_r     = __bits_per_word - __first2.__ctz_;
+      __storage_type __ddn = std::min<__storage_type>(__dn, __clz_r);
+      __m                  = (~__storage_type(0) << __first2.__ctz_) & (~__storage_type(0) >> (__clz_r - __ddn));
+      if (__first2.__ctz_ > __first1.__ctz_) {
+        if ((*__first2.__seg_ & __m) != (__b << (__first2.__ctz_ - __first1.__ctz_)))
+          return false;
+      } else {
+        if ((*__first2.__seg_ & __m) != (__b >> (__first1.__ctz_ - __first2.__ctz_)))
+          return false;
+      }
+      __first2.__seg_ += (__ddn + __first2.__ctz_) / __bits_per_word;
+      __first2.__ctz_ = static_cast<unsigned>((__ddn + __first2.__ctz_) % __bits_per_word);
+      __dn -= __ddn;
+      if (__dn > 0) {
+        __m = ~__storage_type(0) >> (__bits_per_word - __dn);
+        if ((*__first2.__seg_ & __m) != (__b >> (__first1.__ctz_ + __ddn)))
+          return false;
+        __first2.__ctz_ = static_cast<unsigned>(__dn);
+      }
+      ++__first1.__seg_;
+      // __first1.__ctz_ = 0;
+    }
+    // __first1.__ctz_ == 0;
+    // do middle words
+    unsigned __clz_r   = __bits_per_word - __first2.__ctz_;
+    __storage_type __m = ~__storage_type(0) << __first2.__ctz_;
+    for (; __n >= __bits_per_word; __n -= __bits_per_word, ++__first1.__seg_) {
+      __storage_type __b = *__first1.__seg_;
+      if ((*__first2.__seg_ & __m) != (__b << __first2.__ctz_))
+        return false;
+      ++__first2.__seg_;
+      if ((*__first2.__seg_ & ~__m) != (__b >> __clz_r))
+        return false;
+    }
+    // do last word
+    if (__n > 0) {
+      __m                 = ~__storage_type(0) >> (__bits_per_word - __n);
+      __storage_type __b  = *__first1.__seg_ & __m;
+      __storage_type __dn = std::min(__n, static_cast<difference_type>(__clz_r));
+      __m                 = (~__storage_type(0) << __first2.__ctz_) & (~__storage_type(0) >> (__clz_r - __dn));
+      if ((*__first2.__seg_ & __m) != (__b << __first2.__ctz_))
+        return false;
+      __first2.__seg_ += (__dn + __first2.__ctz_) / __bits_per_word;
+      __first2.__ctz_ = static_cast<unsigned>((__dn + __first2.__ctz_) % __bits_per_word);
+      __n -= __dn;
+      if (__n > 0) {
+        __m = ~__storage_type(0) >> (__bits_per_word - __n);
+        if ((*__first2.__seg_ & __m) != (__b >> __dn))
+          return false;
+      }
+    }
+  }
+  return true;
+}
+
+template <class _Cp, bool _IC1, bool _IC2>
+[[__nodiscard__]] _LIBCPP_CONSTEXPR_SINCE_CXX20 _LIBCPP_HIDE_FROM_ABI bool __equal_aligned(
+    __bit_iterator<_Cp, _IC1> __first1, __bit_iterator<_Cp, _IC1> __last1, __bit_iterator<_Cp, _IC2> __first2) {
+  using _It             = __bit_iterator<_Cp, _IC1>;
+  using difference_type = typename _It::difference_type;
+  using __storage_type  = typename _It::__storage_type;
+
+  const int __bits_per_word = _It::__bits_per_word;
+  difference_type __n       = __last1 - __first1;
+  if (__n > 0) {
+    // do first word
+    if (__first1.__ctz_ != 0) {
+      unsigned __clz       = __bits_per_word - __first1.__ctz_;
+      difference_type __dn = std::min(static_cast<difference_type>(__clz), __n);
+      __n -= __dn;
+      __storage_type __m = (~__storage_type(0) << __first1.__ctz_) & (~__storage_type(0) >> (__clz - __dn));
+      if ((*__first2.__seg_ & __m) != (*__first1.__seg_ & __m))
+        return false;
+      ++__first2.__seg_;
+      ++__first1.__seg_;
+      // __first1.__ctz_ = 0;
+      // __first2.__ctz_ = 0;
+    }
+    // __first1.__ctz_ == 0;
+    // __first2.__ctz_ == 0;
+    // do middle words
+    for (; __n >= __bits_per_word; __n -= __bits_per_word, ++__first1.__seg_, ++__first2.__seg_)
+      if (*__first2.__seg_ != *__first1.__seg_)
+        return false;
+    // do last word
+    if (__n > 0) {
+      __storage_type __m = ~__storage_type(0) >> (__bits_per_word - __n);
+      if ((*__first2.__seg_ & __m) != (*__first1.__seg_ & __m))
+        return false;
+    }
+  }
+  return true;
+}
+
+template <class _Cp,
+          bool _IC1,
+          bool _IC2,
+          class _BinaryPredicate,
+          __enable_if_t<std::is_same<_BinaryPredicate, __equal_to>::value, int> = 0>
+[[__nodiscard__]] inline _LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX20 bool __equal_iter_impl(
+    __bit_iterator<_Cp, _IC1> __first1,
+    __bit_iterator<_Cp, _IC1> __last1,
+    __bit_iterator<_Cp, _IC2> __first2,
+    _BinaryPredicate) {
+  if (__first1.__ctz_ == __first2.__ctz_)
+    return std::__equal_aligned(__first1, __last1, __first2);
+  return std::__equal_unaligned(__first1, __last1, __first2);
+}
+
 template <class _InputIterator1, class _InputIterator2, class _BinaryPredicate>
 [[__nodiscard__]] inline _LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX20 bool __equal_iter_impl(
     _InputIterator1 __first1, _InputIterator1 __last1, _InputIterator2 __first2, _BinaryPredicate& __pred) {
@@ -94,6 +228,31 @@ __equal_impl(_Tp* __first1, _Tp* __last1, _Up* __first2, _Up*, _Pred&, _Proj1&,
   return std::__constexpr_memcmp_equal(__first1, __first2, __element_count(__last1 - __first1));
 }
 
+template <class _Cp,
+          bool _IC1,
+          bool _IC2,
+          class _Pred,
+          class _Proj1,
+          class _Proj2,
+          __enable_if_t<(is_same<_Pred, __equal_to>::value
+#  if _LIBCPP_STD_VER >= 20
+                         || is_same<_Pred, ranges::equal_to>::value
+#  endif
+                         ) &&
+                            __desugars_to_v<__equal_tag, _Pred, bool, bool> && __is_identity<_Proj1>::value &&
+                            __is_identity<_Proj2>::value,
+                        int> = 0>
+[[__nodiscard__]] inline _LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX20 bool __equal_impl(
+    __bit_iterator<_Cp, _IC1> __first1,
+    __bit_iterator<_Cp, _IC1> __last1,
+    __bit_iterator<_Cp, _IC2> __first2,
+    __bit_iterator<_Cp, _IC2>,
+    _Pred&,
+    _Proj1&,
+    _Proj2&) {
+  return std::__equal_iter_impl(__first1, __last1, __first2, __equal_to());
+}
+
 template <class _InputIterator1, class _InputIterator2, class _BinaryPredicate>
 [[__nodiscard__]] inline _LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX20 bool
 equal(_InputIterator1 __first1,
diff --git a/libcxx/include/__bit_reference b/libcxx/include/__bit_reference
index 67abb023122edf..8a0c4c93bcbaf7 100644
--- a/libcxx/include/__bit_reference
+++ b/libcxx/include/__bit_reference
@@ -10,7 +10,9 @@
 #ifndef _LIBCPP___BIT_REFERENCE
 #define _LIBCPP___BIT_REFERENCE
 
+#include <__algorithm/comp.h>
 #include <__algorithm/copy_n.h>
+#include <__algorithm/equal.h>
 #include <__algorithm/min.h>
 #include <__bit/countr.h>
 #include <__compare/ordering.h>
@@ -22,7 +24,9 @@
 #include <__memory/construct_at.h>
 #include <__memory/pointer_traits.h>
 #include <__type_traits/conditional.h>
+#include <__type_traits/enable_if.h>
 #include <__type_traits/is_constant_evaluated.h>
+#include <__type_traits/is_same.h>
 #include <__type_traits/void_t.h>
 #include <__utility/swap.h>
 
@@ -669,127 +673,6 @@ rotate(__bit_iterator<_Cp, false> __first, __bit_iterator<_Cp, false> __middle,
   return __r;
 }
 
-// equal
-
-template <class _Cp, bool _IC1, bool _IC2>
-_LIBCPP_CONSTEXPR_SINCE_CXX20 _LIBCPP_HIDE_FROM_ABI bool __equal_unaligned(
-    __bit_iterator<_Cp, _IC1> __first1, __bit_iterator<_Cp, _IC1> __last1, __bit_iterator<_Cp, _IC2> __first2) {
-  using _It             = __bit_iterator<_Cp, _IC1>;
-  using difference_type = typename _It::difference_type;
-  using __storage_type  = typename _It::__storage_type;
-
-  const int __bits_per_word = _It::__bits_per_word;
-  difference_type __n       = __last1 - __first1;
-  if (__n > 0) {
-    // do first word
-    if (__first1.__ctz_ != 0) {
-      unsigned __clz_f     = __bits_per_word - __first1.__ctz_;
-      difference_type __dn = std::min(static_cast<difference_type>(__clz_f), __n);
-      __n -= __dn;
-      __storage_type __m   = (~__storage_type(0) << __first1.__ctz_) & (~__storage_type(0) >> (__clz_f - __dn));
-      __storage_type __b   = *__first1.__seg_ & __m;
-      unsigned __clz_r     = __bits_per_word - __first2.__ctz_;
-      __storage_type __ddn = std::min<__storage_type>(__dn, __clz_r);
-      __m                  = (~__storage_type(0) << __first2.__ctz_) & (~__storage_type(0) >> (__clz_r - __ddn));
-      if (__first2.__ctz_ > __first1.__ctz_) {
-        if ((*__first2.__seg_ & __m) != (__b << (__first2.__ctz_ - __first1.__ctz_)))
-          return false;
-      } else {
-        if ((*__first2.__seg_ & __m) != (__b >> (__first1.__ctz_ - __first2.__ctz_)))
-          return false;
-      }
-      __first2.__seg_ += (__ddn + __first2.__ctz_) / __bits_per_word;
-      __first2.__ctz_ = static_cast<unsigned>((__ddn + __first2.__ctz_) % __bits_per_word);
-      __dn -= __ddn;
-      if (__dn > 0) {
-        __m = ~__storage_type(0) >> (__bits_per_word - __dn);
-        if ((*__first2.__seg_ & __m) != (__b >> (__first1.__ctz_ + __ddn)))
-          return false;
-        __first2.__ctz_ = static_cast<unsigned>(__dn);
-      }
-      ++__first1.__seg_;
-      // __first1.__ctz_ = 0;
-    }
-    // __first1.__ctz_ == 0;
-    // do middle words
-    unsigned __clz_r   = __bits_per_word - __first2.__ctz_;
-    __storage_type __m = ~__storage_type(0) << __first2.__ctz_;
-    for (; __n >= __bits_per_word; __n -= __bits_per_word, ++__first1.__seg_) {
-      __storage_type __b = *__first1.__seg_;
-      if ((*__first2.__seg_ & __m) != (__b << __first2.__ctz_))
-        return false;
-      ++__first2.__seg_;
-      if ((*__first2.__seg_ & ~__m) != (__b >> __clz_r))
-        return false;
-    }
-    // do last word
-    if (__n > 0) {
-      __m                 = ~__storage_type(0) >> (__bits_per_word - __n);
-      __storage_type __b  = *__first1.__seg_ & __m;
-      __storage_type __dn = std::min(__n, static_cast<difference_type>(__clz_r));
-      __m                 = (~__storage_type(0) << __first2.__ctz_) & (~__storage_type(0) >> (__clz_r - __dn));
-      if ((*__first2.__seg_ & __m) != (__b << __first2.__ctz_))
-        return false;
-      __first2.__seg_ += (__dn + __first2.__ctz_) / __bits_per_word;
-      __first2.__ctz_ = static_cast<unsigned>((__dn + __first2.__ctz_) % __bits_per_word);
-      __n -= __dn;
-      if (__n > 0) {
-        __m = ~__storage_type(0) >> (__bits_per_word - __n);
-        if ((*__first2.__seg_ & __m) != (__b >> __dn))
-          return false;
-      }
-    }
-  }
-  return true;
-}
-
-template <class _Cp, bool _IC1, bool _IC2>
-_LIBCPP_CONSTEXPR_SINCE_CXX20 _LIBCPP_HIDE_FROM_ABI bool __equal_aligned(
-    __bit_iterator<_Cp, _IC1> __first1, __bit_iterator<_Cp, _IC1> __last1, __bit_iterator<_Cp, _IC2> __first2) {
-  using _It             = __bit_iterator<_Cp, _IC1>;
-  using difference_type = typename _It::difference_type;
-  using __storage_type  = typename _It::__storage_type;
-
-  const int __bits_per_word = _It::__bits_per_word;
-  difference_type __n       = __last1 - __first1;
-  if (__n > 0) {
-    // do first word
-    if (__first1.__ctz_ != 0) {
-      unsigned __clz       = __bits_per_word - __first1.__ctz_;
-      difference_type __dn = std::min(static_cast<difference_type>(__clz), __n);
-      __n -= __dn;
-      __storage_type __m = (~__storage_type(0) << __first1.__ctz_) & (~__storage_type(0) >> (__clz - __dn));
-      if ((*__first2.__seg_ & __m) != (*__first1.__seg_ & __m))
-        return false;
-      ++__first2.__seg_;
-      ++__first1.__seg_;
-      // __first1.__ctz_ = 0;
-      // __first2.__ctz_ = 0;
-    }
-    // __first1.__ctz_ == 0;
-    // __first2.__ctz_ == 0;
-    // do middle words
-    for (; __n >= __bits_per_word; __n -= __bits_per_word, ++__first1.__seg_, ++__first2.__seg_)
-      if (*__first2.__seg_ != *__first1.__seg_)
-        return false;
-    // do last word
-    if (__n > 0) {
-      __storage_type __m = ~__storage_type(0) >> (__bits_per_word - __n);
-      if ((*__first2.__seg_ & __m) != (*__first1.__seg_ & __m))
-        return false;
-    }
-  }
-  return true;
-}
-
-template <class _Cp, bool _IC1, bool _IC2>
-inline _LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX20 bool
-equal(__bit_iterator<_Cp, _IC1> __first1, __bit_iterator<_Cp, _IC1> __last1, __bit_iterator<_Cp, _IC2> __first2) {
-  if (__first1.__ctz_ == __first2.__ctz_)
-    return std::__equal_aligned(__first1, __last1, __first2);
-  return std::__equal_unaligned(__first1, __last1, __first2);
-}
-
 template <class _Cp, bool _IsConst, typename _Cp::__storage_type>
 class __bit_iterator {
 public:
@@ -1018,9 +901,13 @@ private:
   template <class _Dp, bool _IC1, bool _IC2>
   _LIBCPP_CONSTEXPR_SINCE_CXX20 friend bool
       __equal_unaligned(__bit_iterator<_Dp, _IC1>, __bit_iterator<_Dp, _IC1>, __bit_iterator<_Dp, _IC2>);
-  template <class _Dp, bool _IC1, bool _IC2>
-  _LIBCPP_CONSTEXPR_SINCE_CXX20 friend bool
-      equal(__bit_iterator<_Dp, _IC1>, __bit_iterator<_Dp, _IC1>, __bit_iterator<_Dp, _IC2>);
+  template <class _Dp,
+            bool _IC1,
+            bool _IC2,
+            class _BinaryPredicate,
+            __enable_if_t<std::is_same<_BinaryPredicate, __equal_to>::value, int> >
+  _LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX20 friend bool __equal_iter_impl(
+      __bit_iterator<_Dp, _IC1>, __bit_iterator<_Dp, _IC1>, __bit_iterator<_Dp, _IC2>, _BinaryPredicate);
   template <bool _ToFind, class _Dp, bool _IC>
   _LIBCPP_CONSTEXPR_SINCE_CXX20 friend __bit_iterator<_Dp, _IC>
       __find_bool(__bit_iterator<_Dp, _IC>, typename __size_difference_type_traits<_Dp>::size_type);
diff --git a/libcxx/include/bitset b/libcxx/include/bitset
index 10576eb80bf2ee..a8c499df04232f 100644
--- a/libcxx/include/bitset
+++ b/libcxx/include/bitset
@@ -130,6 +130,7 @@ template <size_t N> struct hash<std::bitset<N>>;
 #  include <__cxx03/bitset>
 #else
 #  include <__algorithm/count.h>
+#  include <__algorithm/equal.h>
 #  include <__algorithm/fill.h>
 #  include <__algorithm/fill_n.h>
 #  include <__algorithm/find.h>
diff --git a/libcxx/test/benchmarks/algorithms/equal.bench.cpp b/libcxx/test/benchmarks/algorithms/equal.bench.cpp
index 2dc11585c15c7f..ac3aa28bb28b39 100644
--- a/libcxx/test/benchmarks/algorithms/equal.bench.cpp
+++ b/libcxx/test/benchmarks/algorithms/equal.bench.cpp
@@ -45,4 +45,55 @@ static void bm_ranges_equal(benchmark::State& state) {
 }
 BENCHMARK(bm_ranges_equal)->DenseRange(1, 8)->Range(16, 1 << 20);
 
+static void bm_ranges_equal_vb_aligned(benchmark::State& state) {
+  auto n = state.range();
+  std::vector<bool> vec1(n, true);
+  std::vector<bool> vec2(n, true);
+  for (auto _ : state) {
+    benchmark::DoNotOptimize(std::ranges::equal(vec1, vec2));
+    benchmark::DoNotOptimize(&vec1);
+    benchmark::DoNotOptimize(&vec2);
+  }
+}
+
+static void bm_ranges_equal_vb_unaligned(benchmark::State& state) {
+  auto n = state.range();
+  std::vector<bool> vec1(n, true);
+  std::vector<bool> vec2(n + 8, true);
+  auto beg1 = std::ranges::begin(vec1);
+  auto end1 = std::ranges::end(vec1);
+  auto beg2 = std::ranges::begin(vec2) + 4;
+  auto end2 = std::ranges::end(vec2) - 4;
+  for (auto _ : state) {
+    benchmark::DoNotOptimize(std::ranges::equal(beg1, end1, beg2, end2));
+    benchmark::DoNotOptimize(&vec1);
+    benchmark::DoNotOptimize(&vec2);
+  }
+}
+
+// Test std::ranges::equal for vector<bool>::iterator
+BENCHMARK(bm_ranges_equal_vb_aligned)->Range(8, 1 << 20);
+BENCHMARK(bm_ranges_equal_vb_unaligned)->Range(8, 1 << 20);
+
+static void bm_equal_vb(benchmark::State& state, bool aligned) {
+  auto n = state.range();
+  std::vector<bool> vec1(n, true);
+  std::vector<bool> vec2(aligned ? n : n + 8, true);
+  auto beg1 = vec1.begin();
+  auto end1 = vec1.end();
+  auto beg2 = aligned ? vec2.begin() : vec2.begin() + 4;
+  for (auto _ : state) {
+    benchmark::DoNotOptimize(std::equal(beg1, end1, beg2));
+    benchmark::DoNotOptimize(&vec1);
+    benchmark::DoNotOptimize(&vec2);
+  }
+}
+
+static void bm_equal_vb_aligned(benchmark::State& state) { bm_equal_vb(state, true); }
+static void bm_equal_vb_unaligned(benchmark::State& state) { bm_equal_vb(state, false); }
+
+// Test std::equal for vector<bool>::iterator
+BENCHMARK(bm_equal_vb_aligned)->Range(8, 1 << 20);
+BENCHMARK(bm_equal_vb_unaligned)->Range(8, 1 << 20);
+
 BENCHMARK_MAIN();
diff --git a/libcxx/test/std/algorithms/alg.nonmodifying/alg.equal/equal.pass.cpp b/libcxx/test/std/algorithms/alg.nonmodifying/alg.equal/equal.pass.cpp
index c3ba3f89b4de3c..a88f041013da6c 100644
--- a/libcxx/test/std/algorithms/alg.nonmodifying/alg.equal/equal.pass.cpp
+++ b/libcxx/test/std/algorithms/alg.nonmodifying/alg.equal/equal.pass.cpp
@@ -28,6 +28,7 @@
 #include <algorithm>
 #include <cassert>
 #include <functional>
+#include <vector>
 
 #include "test_iterators.h"
 #include "test_macros.h"
@@ -123,6 +124,30 @@ class trivially_equality_comparable {
 
 #endif
 
+template <std::size_t N>
+TEST_CONSTEXPR_CXX20 void test_vector_bool() {
+  std::vector<bool> in(N, false);
+  for (std::size_t i = 0; i < N; i += 2)
+    in[i] = true;
+
+  { // Test equal() with aligned bytes
+    std::vector<bool> out = in;
+    assert(std::equal(in.begin(), in.end(), out.begin()));
+#if TEST_STD_VER >= 14
+    assert(std::equal(in.begin(), in.end(), out.begin(), out.end()));
+#endif
+  }
+
+  { // Test equal() with unaligned bytes
+    std::vector<bool> out(N + 8);
+    std::copy(in.begin(), in.end(), out.begin() + 4);
+    assert(std::equal(in.begin(), in.end(), out.begin() + 4));
+#if TEST_STD_VER >= 14
+    assert(std::equal(in.begin(), in.end(), out.begin() + 4, out.end() - 4));
+#endif
+  }
+}
+
 TEST_CONSTEXPR_CXX20 bool test() {
   types::for_each(types::cpp17_input_iterator_list<int*>(), TestIter2<int, types::cpp17_input_iterator_list<int*> >());
   types::for_each(
@@ -138,6 +163,14 @@ TEST_CONSTEXPR_CXX20 bool test() {
       TestIter2<trivially_equality_comparable, types::cpp17_input_iterator_list<trivially_equality_comparable*>>{});
 #endif
 
+  { // Test vector<bool>::iterator optimization
+    test_vector_bool<8>();
+    test_vector_bool<16>();
+    test_vector_bool<32>();
+    test_vector_bool<64>();
+    test_vector_bool<256>();
+  }
+
   return true;
 }
 
diff --git a/libcxx/test/std/algorithms/alg.nonmodifying/alg.equal/ranges.equal.pass.cpp b/libcxx/test/std/algorithms/alg.nonmodifying/alg.equal/ranges.equal.pass.cpp
index f36cd2e0896552..37c3677b445c00 100644
--- a/libcxx/test/std/algorithms/alg.nonmodifying/alg.equal/ra...
[truncated]

@winner245 winner245 force-pushed the optimize-ranges-equal branch 2 times, most recently from 2b9442a to c47c0d8 Compare February 3, 2025 22:27
Copy link
Member

@ldionne ldionne left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM with minor comments, thanks!

@winner245 winner245 force-pushed the optimize-ranges-equal branch 5 times, most recently from 2019761 to a6ee161 Compare February 8, 2025 16:17
@winner245
Copy link
Contributor Author

I've updated the tests for the fill, fill_n algorithms to temporarily use the standard std::equal in their equality assertions. This change is necessary because the std::equal algorithm with the __bit_iterator optimization fails to correctly compare vector<bool> instances with storage types smaller than int, as tracked in #126369. Once this issue is resolved, we should revert to the __bit_iterator optimized version in the fill and fill_n tests. This is clearly stated in the corresponding test files with a FIXME comment.

The purpose of this change is to ensure that this PR can be merged independently of the issue resolution.

Copy link
Member

@ldionne ldionne left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The change LGTM but I'd like to see it again after the operator== change.

@winner245 winner245 force-pushed the optimize-ranges-equal branch from a6ee161 to cb3dc20 Compare February 20, 2025 20:37
@winner245 winner245 force-pushed the optimize-ranges-equal branch from cb3dc20 to 915e176 Compare February 20, 2025 20:51
Copy link
Member

@ldionne ldionne left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks!

@ldionne ldionne merged commit 7717a54 into llvm:main Feb 26, 2025
86 checks passed
@winner245 winner245 deleted the optimize-ranges-equal branch February 26, 2025 21:32
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
libc++ libc++ C++ Standard Library. Not GNU libstdc++. Not libc++abi. performance
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants