Skip to content

Commit 184fbd4

Browse files
fix
1 parent ea09412 commit 184fbd4

File tree

3 files changed

+160
-127
lines changed

3 files changed

+160
-127
lines changed

libc/fuzzing/__support/CMakeLists.txt

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,14 @@ add_libc_fuzzer(
1414
libc.src.__support.HashTable.table
1515
libc.src.string.memcpy
1616
)
17+
18+
add_libc_fuzzer(
19+
hashtable_opt_fuzz
20+
SRCS
21+
hashtable_fuzz.cpp
22+
DEPENDS
23+
libc.src.__support.HashTable.table
24+
libc.src.string.memcpy
25+
COMPILE_OPTIONS
26+
-D__LIBC_EXPLICIT_SIMD_OPT
27+
)

libc/fuzzing/__support/hashtable_fuzz.cpp

Lines changed: 140 additions & 121 deletions
Original file line numberDiff line numberDiff line change
@@ -11,160 +11,179 @@
1111
//===----------------------------------------------------------------------===//
1212
#include "src/__support/CPP/new.h"
1313
#include "src/__support/CPP/optional.h"
14+
#include "src/__support/CPP/string.h"
15+
#include "src/__support/CPP/utility/forward.h"
1416
#include "src/__support/HashTable/table.h"
15-
#include "src/string/memcpy.h"
16-
#include <search.h>
1717
#include <stdint.h>
1818
namespace LIBC_NAMESPACE {
1919

20-
enum class Action { Find, Insert, CrossCheck };
21-
static uint8_t *global_buffer = nullptr;
22-
static size_t remaining = 0;
20+
template <typename T> class UniquePtr {
21+
T *ptr;
2322

24-
static cpp::optional<uint8_t> next_u8() {
25-
if (remaining == 0)
26-
return cpp::nullopt;
27-
uint8_t result = *global_buffer;
28-
global_buffer++;
29-
remaining--;
30-
return result;
31-
}
23+
public:
24+
UniquePtr(T *ptr) : ptr(ptr) {}
25+
~UniquePtr() { delete ptr; }
26+
UniquePtr(UniquePtr &&other) : ptr(other.ptr) { other.ptr = nullptr; }
27+
UniquePtr &operator=(UniquePtr &&other) {
28+
delete ptr;
29+
ptr = other.ptr;
30+
other.ptr = nullptr;
31+
return *this;
32+
}
33+
T *operator->() { return ptr; }
34+
template <typename... U> static UniquePtr create(U &&...x) {
35+
AllocChecker ac;
36+
T *ptr = new (ac) T(cpp::forward<U>(x)...);
37+
if (!ac)
38+
return {nullptr};
39+
return UniquePtr(ptr);
40+
}
41+
operator bool() { return ptr != nullptr; }
42+
T *get() { return ptr; }
43+
};
3244

33-
static cpp::optional<uint64_t> next_uint64() {
34-
uint64_t result;
35-
if (remaining < sizeof(result))
36-
return cpp::nullopt;
37-
memcpy(&result, global_buffer, sizeof(result));
38-
global_buffer += sizeof(result);
39-
remaining -= sizeof(result);
40-
return result;
41-
}
45+
// a tagged union
46+
struct Action {
47+
enum class Tag { Find, Insert, CrossCheck } tag;
48+
cpp::string key;
49+
UniquePtr<Action> next;
50+
Action(Tag tag, cpp::string key, UniquePtr<Action> next)
51+
: tag(tag), key(cpp::move(key)), next(cpp::move(next)) {}
52+
};
4253

43-
static cpp::optional<Action> next_action() {
44-
if (cpp::optional<uint8_t> action = next_u8()) {
45-
switch (*action % 3) {
46-
case 0:
47-
return Action::Find;
48-
case 1:
49-
return Action::Insert;
50-
case 2:
51-
return Action::CrossCheck;
52-
}
54+
static struct {
55+
UniquePtr<Action> actions = nullptr;
56+
size_t remaining;
57+
const char *buffer;
58+
59+
template <typename T> cpp::optional<T> next() {
60+
static_assert(cpp::is_integral<T>::value, "T must be an integral type");
61+
union {
62+
T result;
63+
char data[sizeof(T)];
64+
};
65+
if (remaining < sizeof(result))
66+
return cpp::nullopt;
67+
for (size_t i = 0; i < sizeof(result); i++)
68+
data[i] = buffer[i];
69+
buffer += sizeof(result);
70+
remaining -= sizeof(result);
71+
return result;
5372
}
54-
return cpp::nullopt;
55-
}
5673

57-
static cpp::optional<char *> next_cstr() {
58-
char *result = reinterpret_cast<char *>(global_buffer);
59-
if (cpp::optional<uint8_t> len = next_u8()) {
60-
uint64_t length;
61-
for (length = 0; length < *len; length++) {
62-
if (length >= remaining)
63-
return cpp::nullopt;
64-
if (*global_buffer == '\0')
74+
cpp::optional<cpp::string> next_string() {
75+
if (cpp::optional<uint16_t> len = next<uint16_t>()) {
76+
uint64_t length;
77+
for (length = 0; length < *len && length < remaining; length++)
78+
if (buffer[length] == '\0')
79+
break;
80+
cpp::string result(buffer, length);
81+
result += '\0';
82+
buffer += length;
83+
remaining -= length;
84+
return result;
85+
}
86+
return cpp::nullopt;
87+
}
88+
Action *next_action() {
89+
if (cpp::optional<uint8_t> action = next<uint8_t>()) {
90+
switch (*action % 3) {
91+
case 0: {
92+
if (cpp::optional<cpp::string> key = next_string())
93+
actions = UniquePtr<Action>::create(
94+
Action::Tag::Find, cpp::move(*key), cpp::move(actions));
95+
else
96+
return nullptr;
6597
break;
98+
}
99+
case 1: {
100+
if (cpp::optional<cpp::string> key = next_string())
101+
actions = UniquePtr<Action>::create(
102+
Action::Tag::Insert, cpp::move(*key), cpp::move(actions));
103+
else
104+
return nullptr;
105+
break;
106+
}
107+
case 2: {
108+
actions = UniquePtr<Action>::create(Action::Tag::CrossCheck, "",
109+
cpp::move(actions));
110+
break;
111+
}
112+
}
113+
return actions.get();
66114
}
67-
if (length >= remaining)
68-
return cpp::nullopt;
69-
global_buffer[length] = '\0';
70-
global_buffer += length + 1;
71-
remaining -= length + 1;
72-
return result;
115+
return nullptr;
73116
}
74-
return cpp::nullopt;
75-
}
117+
} global_status;
76118

77-
#define get_value(op) \
78-
__extension__({ \
79-
auto val = op(); \
80-
if (!val) \
81-
return 0; \
82-
*val; \
83-
})
119+
class HashTable {
120+
internal::HashTable *table;
84121

85-
template <typename Fn> struct CleanUpHook {
86-
cpp::optional<Fn> fn;
87-
~CleanUpHook() {
88-
if (fn)
89-
(*fn)();
90-
}
91-
CleanUpHook(Fn fn) : fn(cpp::move(fn)) {}
92-
CleanUpHook(const CleanUpHook &) = delete;
93-
CleanUpHook(CleanUpHook &&other) : fn(cpp::move(other.fn)) {
94-
other.fn = cpp::nullopt;
122+
public:
123+
HashTable(uint64_t size, uint64_t seed)
124+
: table(internal::HashTable::allocate(size, seed)) {}
125+
HashTable(internal::HashTable *table) : table(table) {}
126+
~HashTable() { internal::HashTable::deallocate(table); }
127+
HashTable(HashTable &&other) : table(other.table) { other.table = nullptr; }
128+
bool is_valid() const { return table != nullptr; }
129+
ENTRY *find(const char *key) { return table->find(key); }
130+
ENTRY *insert(const ENTRY &entry) {
131+
return internal::HashTable::insert(this->table, entry);
95132
}
133+
using iterator = internal::HashTable::iterator;
134+
iterator begin() const { return table->begin(); }
135+
iterator end() const { return table->end(); }
96136
};
97137

98-
#define register_cleanup(ID, ...) \
99-
auto cleanup_hook##ID = __extension__({ \
100-
auto a = __VA_ARGS__; \
101-
CleanUpHook<decltype(a)>(cpp::move(a)); \
102-
});
138+
HashTable next_hashtable() {
139+
if (cpp::optional<uint16_t> size = global_status.next<uint16_t>())
140+
if (cpp::optional<uint64_t> seed = global_status.next<uint64_t>())
141+
return HashTable(*size, *seed);
103142

104-
static void trap_with_message(const char *msg) { __builtin_trap(); }
143+
return HashTable(0, 0);
144+
}
105145

106146
extern "C" int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
107-
AllocChecker ac;
108-
global_buffer = static_cast<uint8_t *>(::operator new(size, ac));
109-
register_cleanup(0, [global_buffer = global_buffer, size] {
110-
::operator delete(global_buffer, size);
111-
});
112-
if (!ac)
147+
char key[] = "key";
148+
global_status.buffer = reinterpret_cast<const char *>(data);
149+
global_status.remaining = size;
150+
HashTable table_a = next_hashtable();
151+
HashTable table_b = next_hashtable();
152+
if (!table_a.is_valid() || !table_b.is_valid())
113153
return 0;
114-
memcpy(global_buffer, data, size);
115154

116-
remaining = size;
117-
uint64_t size_a = get_value(next_u8);
118-
uint64_t size_b = get_value(next_u8);
119-
uint64_t rand_a = get_value(next_uint64);
120-
uint64_t rand_b = get_value(next_uint64);
121-
internal::HashTable *table_a = internal::HashTable::allocate(size_a, rand_a);
122-
register_cleanup(1, [&table_a] {
123-
if (table_a)
124-
internal::HashTable::deallocate(table_a);
125-
});
126-
internal::HashTable *table_b = internal::HashTable::allocate(size_b, rand_b);
127-
register_cleanup(2, [&table_b] {
128-
if (table_b)
129-
internal::HashTable::deallocate(table_b);
130-
});
131-
if (!table_a || !table_b)
132-
return 0;
133155
for (;;) {
134-
Action action = get_value(next_action);
135-
switch (action) {
136-
case Action::Find: {
137-
const char *key = get_value(next_cstr);
138-
if (static_cast<bool>(table_a->find(key)) !=
139-
static_cast<bool>(table_b->find(key)))
140-
trap_with_message(key);
156+
Action *action = global_status.next_action();
157+
if (!action)
158+
return 0;
159+
switch (action->tag) {
160+
case Action::Tag::Find: {
161+
if (table_a.find(action->key.c_str()) !=
162+
table_b.find(action->key.c_str()))
163+
__builtin_trap();
141164
break;
142165
}
143-
case Action::Insert: {
144-
char *key = get_value(next_cstr);
145-
ENTRY *a = internal::HashTable::insert(table_a, ENTRY{key, key});
146-
ENTRY *b = internal::HashTable::insert(table_b, ENTRY{key, key});
166+
case Action::Tag::Insert: {
167+
ENTRY *a = table_a.insert(ENTRY{key, key});
168+
ENTRY *b = table_b.insert(ENTRY{key, key});
147169
if (a->data != b->data)
148170
__builtin_trap();
149171
break;
150172
}
151-
case Action::CrossCheck: {
152-
for (ENTRY a : *table_a) {
153-
if (const ENTRY *b = table_b->find(a.key)) {
154-
if (a.data != b->data)
155-
__builtin_trap();
156-
}
157-
}
158-
for (ENTRY b : *table_b) {
159-
if (const ENTRY *a = table_a->find(b.key)) {
160-
if (a->data != b.data)
161-
__builtin_trap();
162-
}
163-
}
173+
case Action::Tag::CrossCheck: {
174+
for (ENTRY a : table_a)
175+
if (const ENTRY *b = table_b.find(a.key); a.data != b->data)
176+
__builtin_trap();
177+
178+
for (ENTRY b : table_b)
179+
if (const ENTRY *a = table_a.find(b.key); a->data != b.data)
180+
__builtin_trap();
181+
164182
break;
165183
}
166184
}
167185
}
186+
return 0;
168187
}
169188

170189
} // namespace LIBC_NAMESPACE

libc/src/__support/HashTable/generic/bitmask_impl.inc

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -34,10 +34,11 @@ LIBC_INLINE constexpr bitmask_t repeat_byte(bitmask_t byte) {
3434
return byte;
3535
}
3636

37-
using BitMask = BitMaskAdaptor<bitmask_t, 0x8ull>;
37+
using BitMask = BitMaskAdaptor<bitmask_t, 0x8ul>;
3838
using IteratableBitMask = IteratableBitMaskAdaptor<BitMask>;
3939

4040
struct Group {
41+
LIBC_INLINE_VAR static constexpr bitmask_t MASK = repeat_byte(0x80ul);
4142
bitmask_t data;
4243

4344
// Load a group of control words from an arbitary address.
@@ -100,21 +101,23 @@ struct Group {
100101
// - The check for key equality will catch these.
101102
// - This only happens if there is at least 1 true match.
102103
// - The chance of this happening is very low (< 1% chance per byte).
104+
static constexpr bitmask_t ONES = repeat_byte(0x01ul);
103105
auto cmp = data ^ repeat_byte(byte);
104-
auto result = LIBC_NAMESPACE::Endian::to_little_endian(
105-
(cmp - repeat_byte(0x01)) & ~cmp & repeat_byte(0x80));
106+
auto result =
107+
LIBC_NAMESPACE::Endian::to_little_endian((cmp - ONES) & ~cmp & MASK);
106108
return {BitMask{result}};
107109
}
108110

109111
// Find out the lanes equal to EMPTY or DELETE (highest bit set) and
110112
// return the bitmask with corresponding bits set.
111113
LIBC_INLINE BitMask mask_available() const {
112-
return {LIBC_NAMESPACE::Endian::to_little_endian(data) & repeat_byte(0x80)};
114+
bitmask_t le_data = LIBC_NAMESPACE::Endian::to_little_endian(data);
115+
return {le_data & MASK};
113116
}
114117

115118
LIBC_INLINE IteratableBitMask occupied() const {
116-
return {
117-
{static_cast<bitmask_t>(mask_available().word ^ repeat_byte(0x80))}};
119+
bitmask_t available = mask_available().word;
120+
return {BitMask{available ^ MASK}};
118121
}
119122
};
120123
} // namespace internal

0 commit comments

Comments
 (0)