Skip to content

Commit 5448de7

Browse files
committed
generalize bitvector code into a bitmatrix; write some unit tests, but
probably not enough. This code is so simple, what could possibly go wrong?
1 parent 6c11e4a commit 5448de7

File tree

1 file changed

+176
-9
lines changed

1 file changed

+176
-9
lines changed

src/librustc_data_structures/bitvec.rs

Lines changed: 176 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -15,26 +15,193 @@ pub struct BitVector {
1515

1616
impl BitVector {
1717
pub fn new(num_bits: usize) -> BitVector {
18-
let num_words = (num_bits + 63) / 64;
18+
let num_words = u64s(num_bits);
1919
BitVector { data: vec![0; num_words] }
2020
}
2121

22-
fn word_mask(&self, bit: usize) -> (usize, u64) {
23-
let word = bit / 64;
24-
let mask = 1 << (bit % 64);
25-
(word, mask)
26-
}
27-
2822
pub fn contains(&self, bit: usize) -> bool {
29-
let (word, mask) = self.word_mask(bit);
23+
let (word, mask) = word_mask(bit);
3024
(self.data[word] & mask) != 0
3125
}
3226

3327
pub fn insert(&mut self, bit: usize) -> bool {
34-
let (word, mask) = self.word_mask(bit);
28+
let (word, mask) = word_mask(bit);
3529
let data = &mut self.data[word];
3630
let value = *data;
3731
*data = value | mask;
3832
(value | mask) != value
3933
}
34+
35+
pub fn insert_all(&mut self, all: &BitVector) -> bool {
36+
assert!(self.data.len() == all.data.len());
37+
let mut changed = false;
38+
for (i, j) in self.data.iter_mut().zip(&all.data) {
39+
let value = *i;
40+
*i = value | *j;
41+
if value != *i { changed = true; }
42+
}
43+
changed
44+
}
45+
46+
pub fn grow(&mut self, num_bits: usize) {
47+
let num_words = u64s(num_bits);
48+
let extra_words = self.data.len() - num_words;
49+
self.data.extend((0..extra_words).map(|_| 0));
50+
}
51+
}
52+
53+
/// A "bit matrix" is basically a square matrix of booleans
54+
/// represented as one gigantic bitvector. In other words, it is as if
55+
/// you have N bitvectors, each of length N.
56+
#[derive(Clone)]
57+
pub struct BitMatrix {
58+
elements: usize,
59+
vector: Vec<u64>,
60+
}
61+
62+
impl BitMatrix {
63+
pub fn new(elements: usize) -> BitMatrix {
64+
// For every element, we need one bit for every other
65+
// element. Round up to an even number of u64s.
66+
let u64s_per_elem = u64s(elements);
67+
BitMatrix {
68+
elements: elements,
69+
vector: vec![0; elements * u64s_per_elem]
70+
}
71+
}
72+
73+
/// The range of bits for a given element.
74+
fn range(&self, element: usize) -> (usize, usize) {
75+
let u64s_per_elem = u64s(self.elements);
76+
let start = element * u64s_per_elem;
77+
(start, start + u64s_per_elem)
78+
}
79+
80+
pub fn add(&mut self, source: usize, target: usize) -> bool {
81+
let (start, _) = self.range(source);
82+
let (word, mask) = word_mask(target);
83+
let mut vector = &mut self.vector[..];
84+
let v1 = vector[start+word];
85+
let v2 = v1 | mask;
86+
vector[start+word] = v2;
87+
v1 != v2
88+
}
89+
90+
/// Do the bits from `source` contain `target`?
91+
/// Put another way, can `source` reach `target`?
92+
pub fn contains(&self, source: usize, target: usize) -> bool {
93+
let (start, _) = self.range(source);
94+
let (word, mask) = word_mask(target);
95+
(self.vector[start+word] & mask) != 0
96+
}
97+
98+
/// Returns those indices that are reachable from both source and
99+
/// target. This is an O(n) operation where `n` is the number of
100+
/// elements (somewhat independent from the actual size of the
101+
/// intersection, in particular).
102+
pub fn intersection(&self, a: usize, b: usize) -> Vec<usize> {
103+
let (a_start, a_end) = self.range(a);
104+
let (b_start, b_end) = self.range(b);
105+
let mut result = Vec::with_capacity(self.elements);
106+
for (base, (i, j)) in (a_start..a_end).zip(b_start..b_end).enumerate() {
107+
let mut v = self.vector[i] & self.vector[j];
108+
for bit in 0..64 {
109+
if v == 0 { break; }
110+
if v & 0x1 != 0 { result.push(base*64 + bit); }
111+
v >>= 1;
112+
}
113+
}
114+
result
115+
}
116+
117+
/// Add the bits from source to the bits from destination,
118+
/// return true if anything changed.
119+
///
120+
/// This is used when computing reachability because if you have
121+
/// an edge `destination -> source`, because in that case
122+
/// `destination` can reach everything that `source` can (and
123+
/// potentially more).
124+
pub fn merge(&mut self, source: usize, destination: usize) -> bool {
125+
let (source_start, source_end) = self.range(source);
126+
let (destination_start, destination_end) = self.range(destination);
127+
let vector = &mut self.vector[..];
128+
let mut changed = false;
129+
for (source_index, destination_index) in
130+
(source_start..source_end).zip(destination_start..destination_end)
131+
{
132+
let v1 = vector[destination_index];
133+
let v2 = v1 | vector[source_index];
134+
vector[destination_index] = v2;
135+
changed = changed | (v1 != v2);
136+
}
137+
changed
138+
}
139+
}
140+
141+
fn u64s(elements: usize) -> usize {
142+
(elements + 63) / 64
143+
}
144+
145+
fn word_mask(index: usize) -> (usize, u64) {
146+
let word = index / 64;
147+
let mask = 1 << (index % 64);
148+
(word, mask)
149+
}
150+
151+
#[test]
152+
fn union_two_vecs() {
153+
let mut vec1 = BitVector::new(65);
154+
let mut vec2 = BitVector::new(65);
155+
assert!(vec1.insert(3));
156+
assert!(!vec1.insert(3));
157+
assert!(vec2.insert(5));
158+
assert!(vec2.insert(64));
159+
assert!(vec1.insert_all(&vec2));
160+
assert!(!vec1.insert_all(&vec2));
161+
assert!(vec1.contains(3));
162+
assert!(!vec1.contains(4));
163+
assert!(vec1.contains(5));
164+
assert!(!vec1.contains(63));
165+
assert!(vec1.contains(64));
166+
}
167+
168+
#[test]
169+
fn grow() {
170+
let mut vec1 = BitVector::new(65);
171+
assert!(vec1.insert(3));
172+
assert!(!vec1.insert(3));
173+
assert!(vec1.insert(5));
174+
assert!(vec1.insert(64));
175+
vec1.grow(128);
176+
assert!(vec1.contains(3));
177+
assert!(vec1.contains(5));
178+
assert!(vec1.contains(64));
179+
assert!(!vec1.contains(126));
180+
}
181+
182+
#[test]
183+
fn matrix_intersection() {
184+
let mut vec1 = BitMatrix::new(200);
185+
186+
vec1.add(2, 3);
187+
vec1.add(2, 6);
188+
vec1.add(2, 10);
189+
vec1.add(2, 64);
190+
vec1.add(2, 65);
191+
vec1.add(2, 130);
192+
vec1.add(2, 160);
193+
194+
vec1.add(65, 2);
195+
vec1.add(65, 8);
196+
vec1.add(65, 10); // X
197+
vec1.add(65, 64); // X
198+
vec1.add(65, 68);
199+
vec1.add(65, 133);
200+
vec1.add(65, 160); // X
201+
202+
let intersection = vec1.intersection(2, 64);
203+
assert!(intersection.is_empty());
204+
205+
let intersection = vec1.intersection(2, 65);
206+
assert_eq!(intersection, vec![10, 64, 160]);
40207
}

0 commit comments

Comments
 (0)