Skip to content

Commit 1eaccab

Browse files
optimize initialization checks
1 parent c9599c4 commit 1eaccab

File tree

1 file changed

+102
-9
lines changed

1 file changed

+102
-9
lines changed

compiler/rustc_middle/src/mir/interpret/allocation.rs

Lines changed: 102 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
//! The virtual memory representation of the MIR interpreter.
22
33
use std::borrow::Cow;
4-
use std::convert::TryFrom;
4+
use std::convert::{TryFrom, TryInto};
55
use std::iter;
66
use std::ops::{Deref, Range};
77
use std::ptr;
@@ -720,13 +720,12 @@ impl InitMask {
720720
return Err(self.len..end);
721721
}
722722

723-
// FIXME(oli-obk): optimize this for allocations larger than a block.
724-
let idx = (start..end).find(|&i| !self.get(i));
723+
let uninit_start = find_bit(self, start, end, false);
725724

726-
match idx {
727-
Some(idx) => {
728-
let uninit_end = (idx..end).find(|&i| self.get(i)).unwrap_or(end);
729-
Err(idx..uninit_end)
725+
match uninit_start {
726+
Some(uninit_start) => {
727+
let uninit_end = find_bit(self, uninit_start, end, true).unwrap_or(end);
728+
Err(uninit_start..uninit_end)
730729
}
731730
None => Ok(()),
732731
}
@@ -863,9 +862,8 @@ impl<'a> Iterator for InitChunkIter<'a> {
863862
}
864863

865864
let is_init = self.init_mask.get(self.start);
866-
// FIXME(oli-obk): optimize this for allocations larger than a block.
867865
let end_of_chunk =
868-
(self.start..self.end).find(|&i| self.init_mask.get(i) != is_init).unwrap_or(self.end);
866+
find_bit(&self.init_mask, self.start, self.end, !is_init).unwrap_or(self.end);
869867
let range = self.start..end_of_chunk;
870868

871869
self.start = end_of_chunk;
@@ -874,10 +872,105 @@ impl<'a> Iterator for InitChunkIter<'a> {
874872
}
875873
}
876874

875+
/// Returns the index of the first bit in `start..end` (end-exclusive) that is equal to is_init.
876+
fn find_bit(init_mask: &InitMask, start: Size, end: Size, is_init: bool) -> Option<Size> {
877+
fn find_bit_fast(init_mask: &InitMask, start: Size, end: Size, is_init: bool) -> Option<Size> {
878+
fn search_block(
879+
bits: Block,
880+
block: usize,
881+
start_bit: usize,
882+
is_init: bool,
883+
) -> Option<Size> {
884+
// invert bits so we're always looking for the first set bit
885+
let bits = if is_init { bits } else { !bits };
886+
// mask off unused start bits
887+
let bits = bits & (!0 << start_bit);
888+
// find set bit, if any
889+
if bits == 0 {
890+
None
891+
} else {
892+
let bit = bits.trailing_zeros();
893+
Some(size_from_bit_index(block, bit))
894+
}
895+
}
896+
897+
if start >= end {
898+
return None;
899+
}
900+
901+
let (start_block, start_bit) = bit_index(start);
902+
let (end_block, end_bit) = bit_index(end);
903+
904+
// handle first block: need to skip `start_bit` bits
905+
if let Some(i) =
906+
search_block(init_mask.blocks[start_block], start_block, start_bit, is_init)
907+
{
908+
if i < end {
909+
return Some(i);
910+
} else {
911+
// if the range is less than a block, we may find a matching bit after `end`
912+
return None;
913+
}
914+
}
915+
916+
let one_block_past_the_end = if end_bit > 0 {
917+
// if `end_bit` > 0, then the range overlaps `end_block`
918+
end_block + 1
919+
} else {
920+
end_block
921+
};
922+
923+
// handle remaining blocks
924+
if start_block < one_block_past_the_end {
925+
for (&bits, block) in init_mask.blocks[start_block + 1..one_block_past_the_end]
926+
.iter()
927+
.zip(start_block + 1..)
928+
{
929+
if let Some(i) = search_block(bits, block, 0, is_init) {
930+
if i < end {
931+
return Some(i);
932+
} else {
933+
// if this is the last block, we may find a matching bit after `end`
934+
return None;
935+
}
936+
}
937+
}
938+
}
939+
940+
None
941+
}
942+
943+
#[cfg_attr(not(debug_assertions), allow(dead_code))]
944+
fn find_bit_slow(init_mask: &InitMask, start: Size, end: Size, is_init: bool) -> Option<Size> {
945+
(start..end).find(|&i| init_mask.get(i) == is_init)
946+
}
947+
948+
let result = find_bit_fast(init_mask, start, end, is_init);
949+
950+
debug_assert_eq!(
951+
result,
952+
find_bit_slow(init_mask, start, end, is_init),
953+
"optimized implementation of find_bit is wrong for start={:?} end={:?} is_init={} init_mask={:#?}",
954+
start,
955+
end,
956+
is_init,
957+
init_mask
958+
);
959+
960+
result
961+
}
962+
877963
#[inline]
878964
fn bit_index(bits: Size) -> (usize, usize) {
879965
let bits = bits.bytes();
880966
let a = bits / InitMask::BLOCK_SIZE;
881967
let b = bits % InitMask::BLOCK_SIZE;
882968
(usize::try_from(a).unwrap(), usize::try_from(b).unwrap())
883969
}
970+
971+
#[inline]
972+
fn size_from_bit_index(block: impl TryInto<u64>, bit: impl TryInto<u64>) -> Size {
973+
let block = block.try_into().ok().unwrap();
974+
let bit = bit.try_into().ok().unwrap();
975+
Size::from_bytes(block * InitMask::BLOCK_SIZE + bit)
976+
}

0 commit comments

Comments
 (0)