Skip to content

Commit 6608908

Browse files
committed
[ADT] Extend EnumeratedArray
EnumeratedArray is essentially a wrapper around a fixed-size array that uses enum values instead of integers as indices. * Add iterator support (begin/end/rbegin/rend), which enables the use of iterator/range based algorithms on EnumeratedArrays. * Add common container typedefs (value_type etc.), allowing drop-in replacements of other containers in cases relying on these. * Add a constructor that takes an std::initializer_list<T>. * Make the size() function const. * Add empty(). Iterator support slightly lowers the protection non-type-safe accesses, because iterator arithmetic is not enum-based, and one can now use *(begin() + IntIndex). However, it is and was also always possible to just cast arbitrary indices to the enum type. Differential Revision: https://reviews.llvm.org/D135594
1 parent a4b0100 commit 6608908

File tree

2 files changed

+108
-2
lines changed

2 files changed

+108
-2
lines changed

llvm/include/llvm/ADT/EnumeratedArray.h

Lines changed: 38 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#define LLVM_ADT_ENUMERATEDARRAY_H
1717

1818
#include <cassert>
19+
#include <iterator>
1920

2021
namespace llvm {
2122

@@ -24,14 +25,33 @@ template <typename ValueType, typename Enumeration,
2425
IndexType Size = 1 + static_cast<IndexType>(LargestEnum)>
2526
class EnumeratedArray {
2627
public:
28+
using iterator = ValueType *;
29+
using const_iterator = const ValueType *;
30+
31+
using const_reverse_iterator = std::reverse_iterator<const_iterator>;
32+
using reverse_iterator = std::reverse_iterator<iterator>;
33+
34+
using value_type = ValueType;
35+
using reference = ValueType &;
36+
using const_reference = const ValueType &;
37+
using pointer = ValueType *;
38+
using const_pointer = const ValueType *;
39+
2740
EnumeratedArray() = default;
2841
EnumeratedArray(ValueType V) {
2942
for (IndexType IX = 0; IX < Size; ++IX) {
3043
Underlying[IX] = V;
3144
}
3245
}
46+
EnumeratedArray(std::initializer_list<ValueType> Init) {
47+
assert(Init.size() == Size && "Incorrect initializer size");
48+
for (IndexType IX = 0; IX < Size; ++IX) {
49+
Underlying[IX] = *(Init.begin() + IX);
50+
}
51+
}
52+
3353
const ValueType &operator[](Enumeration Index) const {
34-
auto IX = static_cast<const IndexType>(Index);
54+
auto IX = static_cast<IndexType>(Index);
3555
assert(IX >= 0 && IX < Size && "Index is out of bounds.");
3656
return Underlying[IX];
3757
}
@@ -40,7 +60,23 @@ class EnumeratedArray {
4060
static_cast<const EnumeratedArray<ValueType, Enumeration, LargestEnum,
4161
IndexType, Size> &>(*this)[Index]);
4262
}
43-
IndexType size() { return Size; }
63+
IndexType size() const { return Size; }
64+
bool empty() const { return size() == 0; }
65+
66+
iterator begin() { return Underlying; }
67+
const_iterator begin() const { return Underlying; }
68+
69+
iterator end() { return begin() + size(); }
70+
const_iterator end() const { return begin() + size(); }
71+
72+
reverse_iterator rbegin() { return reverse_iterator(end()); }
73+
const_reverse_iterator rbegin() const {
74+
return const_reverse_iterator(end());
75+
}
76+
reverse_iterator rend() { return reverse_iterator(begin()); }
77+
const_reverse_iterator rend() const {
78+
return const_reverse_iterator(begin());
79+
}
4480

4581
private:
4682
ValueType Underlying[Size];

llvm/unittests/ADT/EnumeratedArrayTest.cpp

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,10 @@
1111
//===----------------------------------------------------------------------===//
1212

1313
#include "llvm/ADT/EnumeratedArray.h"
14+
#include "llvm/ADT/iterator_range.h"
15+
#include "gmock/gmock.h"
1416
#include "gtest/gtest.h"
17+
#include <type_traits>
1518

1619
namespace llvm {
1720

@@ -46,6 +49,73 @@ TEST(EnumeratedArray, InitAndIndex) {
4649
EXPECT_TRUE(Array2[Colors::Red]);
4750
EXPECT_FALSE(Array2[Colors::Blue]);
4851
EXPECT_TRUE(Array2[Colors::Green]);
52+
53+
EnumeratedArray<float, Colors, Colors::Last, size_t> Array3 = {10.0, 11.0,
54+
12.0};
55+
EXPECT_EQ(Array3[Colors::Red], 10.0);
56+
EXPECT_EQ(Array3[Colors::Blue], 11.0);
57+
EXPECT_EQ(Array3[Colors::Green], 12.0);
58+
}
59+
60+
//===--------------------------------------------------------------------===//
61+
// Test size and empty function
62+
//===--------------------------------------------------------------------===//
63+
64+
TEST(EnumeratedArray, Size) {
65+
66+
enum class Colors { Red, Blue, Green, Last = Green };
67+
68+
EnumeratedArray<float, Colors, Colors::Last, size_t> Array;
69+
const auto &ConstArray = Array;
70+
71+
EXPECT_EQ(ConstArray.size(), 3u);
72+
EXPECT_EQ(ConstArray.empty(), false);
73+
}
74+
75+
//===--------------------------------------------------------------------===//
76+
// Test iterators
77+
//===--------------------------------------------------------------------===//
78+
79+
TEST(EnumeratedArray, Iterators) {
80+
81+
enum class Colors { Red, Blue, Green, Last = Green };
82+
83+
EnumeratedArray<float, Colors, Colors::Last, size_t> Array;
84+
const auto &ConstArray = Array;
85+
86+
Array[Colors::Red] = 1.0;
87+
Array[Colors::Blue] = 2.0;
88+
Array[Colors::Green] = 3.0;
89+
90+
EXPECT_THAT(Array, testing::ElementsAre(1.0, 2.0, 3.0));
91+
EXPECT_THAT(ConstArray, testing::ElementsAre(1.0, 2.0, 3.0));
92+
93+
EXPECT_THAT(make_range(Array.rbegin(), Array.rend()),
94+
testing::ElementsAre(3.0, 2.0, 1.0));
95+
EXPECT_THAT(make_range(ConstArray.rbegin(), ConstArray.rend()),
96+
testing::ElementsAre(3.0, 2.0, 1.0));
4997
}
5098

99+
//===--------------------------------------------------------------------===//
100+
// Test typedefs
101+
//===--------------------------------------------------------------------===//
102+
103+
namespace {
104+
105+
enum class Colors { Red, Blue, Green, Last = Green };
106+
107+
using Array = EnumeratedArray<float, Colors, Colors::Last, size_t>;
108+
109+
static_assert(std::is_same<Array::value_type, float>::value,
110+
"Incorrect value_type type");
111+
static_assert(std::is_same<Array::reference, float &>::value,
112+
"Incorrect reference type!");
113+
static_assert(std::is_same<Array::pointer, float *>::value,
114+
"Incorrect pointer type!");
115+
static_assert(std::is_same<Array::const_reference, const float &>::value,
116+
"Incorrect const_reference type!");
117+
static_assert(std::is_same<Array::const_pointer, const float *>::value,
118+
"Incorrect const_pointer type!");
119+
} // namespace
120+
51121
} // namespace llvm

0 commit comments

Comments
 (0)