|
11 | 11 | //===----------------------------------------------------------------------===//
|
12 | 12 | #include "src/__support/CPP/new.h"
|
13 | 13 | #include "src/__support/CPP/optional.h"
|
| 14 | +#include "src/__support/CPP/string.h" |
| 15 | +#include "src/__support/CPP/utility/forward.h" |
14 | 16 | #include "src/__support/HashTable/table.h"
|
15 |
| -#include "src/string/memcpy.h" |
16 |
| -#include <search.h> |
17 | 17 | #include <stdint.h>
|
18 | 18 | namespace LIBC_NAMESPACE {
|
19 | 19 |
|
20 |
| -enum class Action { Find, Insert, CrossCheck }; |
21 |
| -static uint8_t *global_buffer = nullptr; |
22 |
| -static size_t remaining = 0; |
| 20 | +template <typename T> class UniquePtr { |
| 21 | + T *ptr; |
23 | 22 |
|
24 |
| -static cpp::optional<uint8_t> next_u8() { |
25 |
| - if (remaining == 0) |
26 |
| - return cpp::nullopt; |
27 |
| - uint8_t result = *global_buffer; |
28 |
| - global_buffer++; |
29 |
| - remaining--; |
30 |
| - return result; |
31 |
| -} |
| 23 | +public: |
| 24 | + UniquePtr(T *ptr) : ptr(ptr) {} |
| 25 | + ~UniquePtr() { delete ptr; } |
| 26 | + UniquePtr(UniquePtr &&other) : ptr(other.ptr) { other.ptr = nullptr; } |
| 27 | + UniquePtr &operator=(UniquePtr &&other) { |
| 28 | + delete ptr; |
| 29 | + ptr = other.ptr; |
| 30 | + other.ptr = nullptr; |
| 31 | + return *this; |
| 32 | + } |
| 33 | + T *operator->() { return ptr; } |
| 34 | + template <typename... U> static UniquePtr create(U &&...x) { |
| 35 | + AllocChecker ac; |
| 36 | + T *ptr = new (ac) T(cpp::forward<U>(x)...); |
| 37 | + if (!ac) |
| 38 | + return {nullptr}; |
| 39 | + return UniquePtr(ptr); |
| 40 | + } |
| 41 | + operator bool() { return ptr != nullptr; } |
| 42 | + T *get() { return ptr; } |
| 43 | +}; |
32 | 44 |
|
33 |
| -static cpp::optional<uint64_t> next_uint64() { |
34 |
| - uint64_t result; |
35 |
| - if (remaining < sizeof(result)) |
36 |
| - return cpp::nullopt; |
37 |
| - memcpy(&result, global_buffer, sizeof(result)); |
38 |
| - global_buffer += sizeof(result); |
39 |
| - remaining -= sizeof(result); |
40 |
| - return result; |
41 |
| -} |
| 45 | +// a tagged union |
| 46 | +struct Action { |
| 47 | + enum class Tag { Find, Insert, CrossCheck } tag; |
| 48 | + cpp::string key; |
| 49 | + UniquePtr<Action> next; |
| 50 | + Action(Tag tag, cpp::string key, UniquePtr<Action> next) |
| 51 | + : tag(tag), key(cpp::move(key)), next(cpp::move(next)) {} |
| 52 | +}; |
42 | 53 |
|
43 |
| -static cpp::optional<Action> next_action() { |
44 |
| - if (cpp::optional<uint8_t> action = next_u8()) { |
45 |
| - switch (*action % 3) { |
46 |
| - case 0: |
47 |
| - return Action::Find; |
48 |
| - case 1: |
49 |
| - return Action::Insert; |
50 |
| - case 2: |
51 |
| - return Action::CrossCheck; |
52 |
| - } |
| 54 | +static struct { |
| 55 | + UniquePtr<Action> actions = nullptr; |
| 56 | + size_t remaining; |
| 57 | + const char *buffer; |
| 58 | + |
| 59 | + template <typename T> cpp::optional<T> next() { |
| 60 | + static_assert(cpp::is_integral<T>::value, "T must be an integral type"); |
| 61 | + union { |
| 62 | + T result; |
| 63 | + char data[sizeof(T)]; |
| 64 | + }; |
| 65 | + if (remaining < sizeof(result)) |
| 66 | + return cpp::nullopt; |
| 67 | + for (size_t i = 0; i < sizeof(result); i++) |
| 68 | + data[i] = buffer[i]; |
| 69 | + buffer += sizeof(result); |
| 70 | + remaining -= sizeof(result); |
| 71 | + return result; |
53 | 72 | }
|
54 |
| - return cpp::nullopt; |
55 |
| -} |
56 | 73 |
|
57 |
| -static cpp::optional<char *> next_cstr() { |
58 |
| - char *result = reinterpret_cast<char *>(global_buffer); |
59 |
| - if (cpp::optional<uint8_t> len = next_u8()) { |
60 |
| - uint64_t length; |
61 |
| - for (length = 0; length < *len; length++) { |
62 |
| - if (length >= remaining) |
63 |
| - return cpp::nullopt; |
64 |
| - if (*global_buffer == '\0') |
| 74 | + cpp::optional<cpp::string> next_string() { |
| 75 | + if (cpp::optional<uint16_t> len = next<uint16_t>()) { |
| 76 | + uint64_t length; |
| 77 | + for (length = 0; length < *len && length < remaining; length++) |
| 78 | + if (buffer[length] == '\0') |
| 79 | + break; |
| 80 | + cpp::string result(buffer, length); |
| 81 | + result += '\0'; |
| 82 | + buffer += length; |
| 83 | + remaining -= length; |
| 84 | + return result; |
| 85 | + } |
| 86 | + return cpp::nullopt; |
| 87 | + } |
| 88 | + Action *next_action() { |
| 89 | + if (cpp::optional<uint8_t> action = next<uint8_t>()) { |
| 90 | + switch (*action % 3) { |
| 91 | + case 0: { |
| 92 | + if (cpp::optional<cpp::string> key = next_string()) |
| 93 | + actions = UniquePtr<Action>::create( |
| 94 | + Action::Tag::Find, cpp::move(*key), cpp::move(actions)); |
| 95 | + else |
| 96 | + return nullptr; |
65 | 97 | break;
|
| 98 | + } |
| 99 | + case 1: { |
| 100 | + if (cpp::optional<cpp::string> key = next_string()) |
| 101 | + actions = UniquePtr<Action>::create( |
| 102 | + Action::Tag::Insert, cpp::move(*key), cpp::move(actions)); |
| 103 | + else |
| 104 | + return nullptr; |
| 105 | + break; |
| 106 | + } |
| 107 | + case 2: { |
| 108 | + actions = UniquePtr<Action>::create(Action::Tag::CrossCheck, "", |
| 109 | + cpp::move(actions)); |
| 110 | + break; |
| 111 | + } |
| 112 | + } |
| 113 | + return actions.get(); |
66 | 114 | }
|
67 |
| - if (length >= remaining) |
68 |
| - return cpp::nullopt; |
69 |
| - global_buffer[length] = '\0'; |
70 |
| - global_buffer += length + 1; |
71 |
| - remaining -= length + 1; |
72 |
| - return result; |
| 115 | + return nullptr; |
73 | 116 | }
|
74 |
| - return cpp::nullopt; |
75 |
| -} |
| 117 | +} global_status; |
76 | 118 |
|
77 |
| -#define get_value(op) \ |
78 |
| - __extension__({ \ |
79 |
| - auto val = op(); \ |
80 |
| - if (!val) \ |
81 |
| - return 0; \ |
82 |
| - *val; \ |
83 |
| - }) |
| 119 | +class HashTable { |
| 120 | + internal::HashTable *table; |
84 | 121 |
|
85 |
| -template <typename Fn> struct CleanUpHook { |
86 |
| - cpp::optional<Fn> fn; |
87 |
| - ~CleanUpHook() { |
88 |
| - if (fn) |
89 |
| - (*fn)(); |
90 |
| - } |
91 |
| - CleanUpHook(Fn fn) : fn(cpp::move(fn)) {} |
92 |
| - CleanUpHook(const CleanUpHook &) = delete; |
93 |
| - CleanUpHook(CleanUpHook &&other) : fn(cpp::move(other.fn)) { |
94 |
| - other.fn = cpp::nullopt; |
| 122 | +public: |
| 123 | + HashTable(uint64_t size, uint64_t seed) |
| 124 | + : table(internal::HashTable::allocate(size, seed)) {} |
| 125 | + HashTable(internal::HashTable *table) : table(table) {} |
| 126 | + ~HashTable() { internal::HashTable::deallocate(table); } |
| 127 | + HashTable(HashTable &&other) : table(other.table) { other.table = nullptr; } |
| 128 | + bool is_valid() const { return table != nullptr; } |
| 129 | + ENTRY *find(const char *key) { return table->find(key); } |
| 130 | + ENTRY *insert(const ENTRY &entry) { |
| 131 | + return internal::HashTable::insert(this->table, entry); |
95 | 132 | }
|
| 133 | + using iterator = internal::HashTable::iterator; |
| 134 | + iterator begin() const { return table->begin(); } |
| 135 | + iterator end() const { return table->end(); } |
96 | 136 | };
|
97 | 137 |
|
98 |
| -#define register_cleanup(ID, ...) \ |
99 |
| - auto cleanup_hook##ID = __extension__({ \ |
100 |
| - auto a = __VA_ARGS__; \ |
101 |
| - CleanUpHook<decltype(a)>(cpp::move(a)); \ |
102 |
| - }); |
| 138 | +HashTable next_hashtable() { |
| 139 | + if (cpp::optional<uint16_t> size = global_status.next<uint16_t>()) |
| 140 | + if (cpp::optional<uint64_t> seed = global_status.next<uint64_t>()) |
| 141 | + return HashTable(*size, *seed); |
103 | 142 |
|
104 |
| -static void trap_with_message(const char *msg) { __builtin_trap(); } |
| 143 | + return HashTable(0, 0); |
| 144 | +} |
105 | 145 |
|
106 | 146 | extern "C" int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
|
107 |
| - AllocChecker ac; |
108 |
| - global_buffer = static_cast<uint8_t *>(::operator new(size, ac)); |
109 |
| - register_cleanup(0, [global_buffer = global_buffer, size] { |
110 |
| - ::operator delete(global_buffer, size); |
111 |
| - }); |
112 |
| - if (!ac) |
| 147 | + char key[] = "key"; |
| 148 | + global_status.buffer = reinterpret_cast<const char *>(data); |
| 149 | + global_status.remaining = size; |
| 150 | + HashTable table_a = next_hashtable(); |
| 151 | + HashTable table_b = next_hashtable(); |
| 152 | + if (!table_a.is_valid() || !table_b.is_valid()) |
113 | 153 | return 0;
|
114 |
| - memcpy(global_buffer, data, size); |
115 | 154 |
|
116 |
| - remaining = size; |
117 |
| - uint64_t size_a = get_value(next_u8); |
118 |
| - uint64_t size_b = get_value(next_u8); |
119 |
| - uint64_t rand_a = get_value(next_uint64); |
120 |
| - uint64_t rand_b = get_value(next_uint64); |
121 |
| - internal::HashTable *table_a = internal::HashTable::allocate(size_a, rand_a); |
122 |
| - register_cleanup(1, [&table_a] { |
123 |
| - if (table_a) |
124 |
| - internal::HashTable::deallocate(table_a); |
125 |
| - }); |
126 |
| - internal::HashTable *table_b = internal::HashTable::allocate(size_b, rand_b); |
127 |
| - register_cleanup(2, [&table_b] { |
128 |
| - if (table_b) |
129 |
| - internal::HashTable::deallocate(table_b); |
130 |
| - }); |
131 |
| - if (!table_a || !table_b) |
132 |
| - return 0; |
133 | 155 | for (;;) {
|
134 |
| - Action action = get_value(next_action); |
135 |
| - switch (action) { |
136 |
| - case Action::Find: { |
137 |
| - const char *key = get_value(next_cstr); |
138 |
| - if (static_cast<bool>(table_a->find(key)) != |
139 |
| - static_cast<bool>(table_b->find(key))) |
140 |
| - trap_with_message(key); |
| 156 | + Action *action = global_status.next_action(); |
| 157 | + if (!action) |
| 158 | + return 0; |
| 159 | + switch (action->tag) { |
| 160 | + case Action::Tag::Find: { |
| 161 | + if (table_a.find(action->key.c_str()) != |
| 162 | + table_b.find(action->key.c_str())) |
| 163 | + __builtin_trap(); |
141 | 164 | break;
|
142 | 165 | }
|
143 |
| - case Action::Insert: { |
144 |
| - char *key = get_value(next_cstr); |
145 |
| - ENTRY *a = internal::HashTable::insert(table_a, ENTRY{key, key}); |
146 |
| - ENTRY *b = internal::HashTable::insert(table_b, ENTRY{key, key}); |
| 166 | + case Action::Tag::Insert: { |
| 167 | + ENTRY *a = table_a.insert(ENTRY{key, key}); |
| 168 | + ENTRY *b = table_b.insert(ENTRY{key, key}); |
147 | 169 | if (a->data != b->data)
|
148 | 170 | __builtin_trap();
|
149 | 171 | break;
|
150 | 172 | }
|
151 |
| - case Action::CrossCheck: { |
152 |
| - for (ENTRY a : *table_a) { |
153 |
| - if (const ENTRY *b = table_b->find(a.key)) { |
154 |
| - if (a.data != b->data) |
155 |
| - __builtin_trap(); |
156 |
| - } |
157 |
| - } |
158 |
| - for (ENTRY b : *table_b) { |
159 |
| - if (const ENTRY *a = table_a->find(b.key)) { |
160 |
| - if (a->data != b.data) |
161 |
| - __builtin_trap(); |
162 |
| - } |
163 |
| - } |
| 173 | + case Action::Tag::CrossCheck: { |
| 174 | + for (ENTRY a : table_a) |
| 175 | + if (const ENTRY *b = table_b.find(a.key); a.data != b->data) |
| 176 | + __builtin_trap(); |
| 177 | + |
| 178 | + for (ENTRY b : table_b) |
| 179 | + if (const ENTRY *a = table_a.find(b.key); a->data != b.data) |
| 180 | + __builtin_trap(); |
| 181 | + |
164 | 182 | break;
|
165 | 183 | }
|
166 | 184 | }
|
167 | 185 | }
|
| 186 | + return 0; |
168 | 187 | }
|
169 | 188 |
|
170 | 189 | } // namespace LIBC_NAMESPACE
|
0 commit comments