Skip to content

Commit 981ce8f

Browse files
committed
[ADT] Fix const-correctness issues in zippy
This defines the iterator tuple based on the storage type of `zippy`, instead of its type arguments. This way, we can support temporaries that gets passed in and allow for them to be modified during iteration. Because the iterator types to the tuple storage can have different types when the storage is and isn't const, this defines a const iterator type and non-const `begin`/`end` functions. This way we avoid unintentional casts, e.g., trying to cast `vector<bool>::reference` to `vector<bool>::const_reference`, which may be unrelated types that are not convertible. This patch is a general and free-standing improvement but my primary use is in the implemention a version of `enumerate` that accepts multiple ranges: D144583. Reviewed By: dblaikie, zero9178 Differential Revision: https://reviews.llvm.org/D144834
1 parent 466b432 commit 981ce8f

File tree

2 files changed

+157
-14
lines changed

2 files changed

+157
-14
lines changed

llvm/include/llvm/ADT/STLExtras.h

Lines changed: 50 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -856,33 +856,70 @@ class zip_shortest : public zip_common<zip_shortest<Iters...>, Iters...> {
856856
}
857857
};
858858

859+
/// Helper to obtain the iterator types for the tuple storage within `zippy`.
860+
template <template <typename...> class ItType, typename TupleStorageType,
861+
typename IndexSequence>
862+
struct ZippyIteratorTuple;
863+
864+
/// Partial specialization for non-const tuple storage.
865+
template <template <typename...> class ItType, typename... Args,
866+
std::size_t... Ns>
867+
struct ZippyIteratorTuple<ItType, std::tuple<Args...>,
868+
std::index_sequence<Ns...>> {
869+
using type = ItType<decltype(adl_begin(
870+
std::get<Ns>(declval<std::tuple<Args...> &>())))...>;
871+
};
872+
873+
/// Partial specialization for const tuple storage.
874+
template <template <typename...> class ItType, typename... Args,
875+
std::size_t... Ns>
876+
struct ZippyIteratorTuple<ItType, const std::tuple<Args...>,
877+
std::index_sequence<Ns...>> {
878+
using type = ItType<decltype(adl_begin(
879+
std::get<Ns>(declval<const std::tuple<Args...> &>())))...>;
880+
};
881+
859882
template <template <typename...> class ItType, typename... Args> class zippy {
883+
private:
884+
std::tuple<Args...> storage;
885+
using IndexSequence = std::index_sequence_for<Args...>;
886+
860887
public:
861-
using iterator = ItType<decltype(std::begin(std::declval<Args>()))...>;
888+
using iterator = typename ZippyIteratorTuple<ItType, decltype(storage),
889+
IndexSequence>::type;
890+
using const_iterator =
891+
typename ZippyIteratorTuple<ItType, const decltype(storage),
892+
IndexSequence>::type;
862893
using iterator_category = typename iterator::iterator_category;
863894
using value_type = typename iterator::value_type;
864895
using difference_type = typename iterator::difference_type;
865896
using pointer = typename iterator::pointer;
866897
using reference = typename iterator::reference;
898+
using const_reference = typename const_iterator::reference;
867899

868-
private:
869-
std::tuple<Args...> ts;
900+
zippy(Args &&...args) : storage(std::forward<Args>(args)...) {}
870901

902+
const_iterator begin() const { return begin_impl(IndexSequence{}); }
903+
iterator begin() { return begin_impl(IndexSequence{}); }
904+
const_iterator end() const { return end_impl(IndexSequence{}); }
905+
iterator end() { return end_impl(IndexSequence{}); }
906+
907+
private:
871908
template <size_t... Ns>
872-
iterator begin_impl(std::index_sequence<Ns...>) const {
873-
return iterator(std::begin(std::get<Ns>(ts))...);
909+
const_iterator begin_impl(std::index_sequence<Ns...>) const {
910+
return const_iterator(adl_begin(std::get<Ns>(storage))...);
874911
}
875-
template <size_t... Ns> iterator end_impl(std::index_sequence<Ns...>) const {
876-
return iterator(std::end(std::get<Ns>(ts))...);
912+
template <size_t... Ns> iterator begin_impl(std::index_sequence<Ns...>) {
913+
return iterator(adl_begin(std::get<Ns>(storage))...);
877914
}
878915

879-
public:
880-
zippy(Args &&... ts_) : ts(std::forward<Args>(ts_)...) {}
881-
882-
iterator begin() const {
883-
return begin_impl(std::index_sequence_for<Args...>{});
916+
template <size_t... Ns>
917+
const_iterator end_impl(std::index_sequence<Ns...>) const {
918+
return const_iterator(adl_end(std::get<Ns>(storage))...);
919+
}
920+
template <size_t... Ns> iterator end_impl(std::index_sequence<Ns...>) {
921+
return iterator(adl_end(std::get<Ns>(storage))...);
884922
}
885-
iterator end() const { return end_impl(std::index_sequence_for<Args...>{}); }
886923
};
887924

888925
} // end namespace detail

llvm/unittests/ADT/IteratorTest.cpp

Lines changed: 107 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,19 @@
66
//
77
//===----------------------------------------------------------------------===//
88

9-
#include "llvm/ADT/ilist.h"
109
#include "llvm/ADT/iterator.h"
1110
#include "llvm/ADT/ArrayRef.h"
1211
#include "llvm/ADT/STLExtras.h"
1312
#include "llvm/ADT/SmallVector.h"
13+
#include "llvm/ADT/ilist.h"
14+
#include "gmock/gmock.h"
1415
#include "gtest/gtest.h"
1516
#include <optional>
17+
#include <type_traits>
18+
#include <vector>
1619

1720
using namespace llvm;
21+
using testing::ElementsAre;
1822

1923
namespace {
2024

@@ -430,6 +434,108 @@ TEST(ZipIteratorTest, ZipEqualBasic) {
430434
EXPECT_EQ(iters, 6u);
431435
}
432436

437+
template <typename T>
438+
constexpr bool IsConstRef =
439+
std::is_reference_v<T> && std::is_const_v<std::remove_reference_t<T>>;
440+
441+
template <typename T>
442+
constexpr bool IsBoolConstRef =
443+
std::is_same_v<llvm::remove_cvref_t<T>, std::vector<bool>::const_reference>;
444+
445+
/// Returns a `const` copy of the passed value. The `const` on the returned
446+
/// value is intentional here so that `MakeConst` can be used in range-for
447+
/// loops.
448+
template <typename T> const T MakeConst(T &&value) {
449+
return std::forward<T>(value);
450+
}
451+
452+
TEST(ZipIteratorTest, ZipEqualConstCorrectness) {
453+
const std::vector<unsigned> c_first = {3, 1, 4};
454+
std::vector<unsigned> first = c_first;
455+
const SmallVector<bool> c_second = {1, 1, 0};
456+
SmallVector<bool> second = c_second;
457+
458+
for (auto [a, b, c, d] : zip_equal(c_first, first, c_second, second)) {
459+
b = 0;
460+
d = true;
461+
static_assert(IsConstRef<decltype(a)>);
462+
static_assert(!IsConstRef<decltype(b)>);
463+
static_assert(IsConstRef<decltype(c)>);
464+
static_assert(!IsConstRef<decltype(d)>);
465+
}
466+
467+
EXPECT_THAT(first, ElementsAre(0, 0, 0));
468+
EXPECT_THAT(second, ElementsAre(true, true, true));
469+
470+
std::vector<bool> nemesis = {true, false, true};
471+
const std::vector<bool> c_nemesis = nemesis;
472+
473+
for (auto &&[a, b, c, d] : zip_equal(first, c_first, nemesis, c_nemesis)) {
474+
a = 2;
475+
c = true;
476+
static_assert(!IsConstRef<decltype(a)>);
477+
static_assert(IsConstRef<decltype(b)>);
478+
static_assert(!IsBoolConstRef<decltype(c)>);
479+
static_assert(IsBoolConstRef<decltype(d)>);
480+
}
481+
482+
EXPECT_THAT(first, ElementsAre(2, 2, 2));
483+
EXPECT_THAT(nemesis, ElementsAre(true, true, true));
484+
485+
unsigned iters = 0;
486+
for (const auto &[a, b, c, d] :
487+
zip_equal(first, c_first, nemesis, c_nemesis)) {
488+
static_assert(!IsConstRef<decltype(a)>);
489+
static_assert(IsConstRef<decltype(b)>);
490+
static_assert(!IsBoolConstRef<decltype(c)>);
491+
static_assert(IsBoolConstRef<decltype(d)>);
492+
++iters;
493+
}
494+
EXPECT_EQ(iters, 3u);
495+
iters = 0;
496+
497+
for (const auto &[a, b, c, d] :
498+
MakeConst(zip_equal(first, c_first, nemesis, c_nemesis))) {
499+
static_assert(!IsConstRef<decltype(a)>);
500+
static_assert(IsConstRef<decltype(b)>);
501+
static_assert(!IsBoolConstRef<decltype(c)>);
502+
static_assert(IsBoolConstRef<decltype(d)>);
503+
++iters;
504+
}
505+
EXPECT_EQ(iters, 3u);
506+
}
507+
508+
TEST(ZipIteratorTest, ZipEqualTemporaries) {
509+
unsigned iters = 0;
510+
511+
// These temporary ranges get moved into the `tuple<...> storage;` inside
512+
// `zippy`. From then on, we can use references obtained from this storage to
513+
// access them. This does not rely on any lifetime extensions on the
514+
// temporaries passed to `zip_equal`.
515+
for (auto [a, b, c] : zip_equal(SmallVector<int>{1, 2, 3}, std::string("abc"),
516+
std::vector<bool>{true, false, true})) {
517+
a = 3;
518+
b = 'c';
519+
c = false;
520+
static_assert(!IsConstRef<decltype(a)>);
521+
static_assert(!IsConstRef<decltype(b)>);
522+
static_assert(!IsBoolConstRef<decltype(c)>);
523+
++iters;
524+
}
525+
EXPECT_EQ(iters, 3u);
526+
iters = 0;
527+
528+
for (auto [a, b, c] :
529+
MakeConst(zip_equal(SmallVector<int>{1, 2, 3}, std::string("abc"),
530+
std::vector<bool>{true, false, true}))) {
531+
static_assert(IsConstRef<decltype(a)>);
532+
static_assert(IsConstRef<decltype(b)>);
533+
static_assert(IsBoolConstRef<decltype(c)>);
534+
++iters;
535+
}
536+
EXPECT_EQ(iters, 3u);
537+
}
538+
433539
#if !defined(NDEBUG) && GTEST_HAS_DEATH_TEST
434540
// Check that an assertion is triggered when ranges passed to `zip_equal` differ
435541
// in length.

0 commit comments

Comments
 (0)