Skip to content

Commit 76c49ae

Browse files
committed
support non-null pointer niches in CTFE
1 parent 3c05276 commit 76c49ae

File tree

8 files changed

+100
-65
lines changed

8 files changed

+100
-65
lines changed

compiler/rustc_abi/src/lib.rs

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1006,6 +1006,43 @@ impl WrappingRange {
10061006
}
10071007
}
10081008

1009+
/// Returns `true` if `range` is contained in `self`.
1010+
#[inline(always)]
1011+
pub fn contains_range<I: Into<u128> + Ord>(&self, range: RangeInclusive<I>) -> bool {
1012+
if range.is_empty() {
1013+
return true;
1014+
}
1015+
1016+
let (vmin, vmax) = range.into_inner();
1017+
let (vmin, vmax) = (vmin.into(), vmax.into());
1018+
1019+
if self.start <= self.end {
1020+
self.start <= vmin && vmax <= self.end
1021+
} else {
1022+
// The last check is needed to cover the following case:
1023+
// `vmin ... start, end ... vmax`. In this special case there is no gap
1024+
// between `start` and `end` so we must return true.
1025+
self.start <= vmin || vmax <= self.end || self.start == self.end + 1
1026+
}
1027+
}
1028+
1029+
/// Returns `true` if `range` has an overlap with `self`.
1030+
#[inline(always)]
1031+
pub fn overlaps_range<I: Into<u128> + Ord>(&self, range: RangeInclusive<I>) -> bool {
1032+
if range.is_empty() {
1033+
return false;
1034+
}
1035+
1036+
let (vmin, vmax) = range.into_inner();
1037+
let (vmin, vmax) = (vmin.into(), vmax.into());
1038+
1039+
if self.start <= self.end {
1040+
self.start <= vmax && vmin <= self.end
1041+
} else {
1042+
self.start <= vmax || vmin <= self.end
1043+
}
1044+
}
1045+
10091046
/// Returns `self` with replaced `start`
10101047
#[inline(always)]
10111048
pub fn with_start(mut self, start: u128) -> Self {

compiler/rustc_const_eval/messages.ftl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -244,7 +244,6 @@ const_eval_not_enough_caller_args =
244244
const_eval_null_box = {$front_matter}: encountered a null box
245245
const_eval_null_fn_ptr = {$front_matter}: encountered a null function pointer
246246
const_eval_null_ref = {$front_matter}: encountered a null reference
247-
const_eval_nullable_ptr_out_of_range = {$front_matter}: encountered a potentially null pointer, but expected something that cannot possibly fail to be {$in_range}
248247
const_eval_nullary_intrinsic_fail =
249248
could not evaluate nullary intrinsic
250249

compiler/rustc_const_eval/src/const_eval/machine.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -333,7 +333,7 @@ impl<'mir, 'tcx: 'mir> CompileTimeEvalContext<'mir, 'tcx> {
333333
// Inequality with integers other than null can never be known for sure.
334334
(Scalar::Int(int), ptr @ Scalar::Ptr(..))
335335
| (ptr @ Scalar::Ptr(..), Scalar::Int(int))
336-
if int.is_null() && !self.scalar_may_be_null(ptr)? =>
336+
if int.is_null() && !self.ptr_scalar_range(ptr)?.contains(&0) =>
337337
{
338338
0
339339
}

compiler/rustc_const_eval/src/errors.rs

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -617,7 +617,6 @@ impl<'tcx> ReportErrorExt for ValidationErrorInfo<'tcx> {
617617
MutableRefInConst => const_eval_mutable_ref_in_const,
618618
NullFnPtr => const_eval_null_fn_ptr,
619619
NeverVal => const_eval_never_val,
620-
NullablePtrOutOfRange { .. } => const_eval_nullable_ptr_out_of_range,
621620
PtrOutOfRange { .. } => const_eval_ptr_out_of_range,
622621
OutOfRange { .. } => const_eval_out_of_range,
623622
UnsafeCell => const_eval_unsafe_cell,
@@ -732,9 +731,7 @@ impl<'tcx> ReportErrorExt for ValidationErrorInfo<'tcx> {
732731
| InvalidFnPtr { value } => {
733732
err.set_arg("value", value);
734733
}
735-
NullablePtrOutOfRange { range, max_value } | PtrOutOfRange { range, max_value } => {
736-
add_range_arg(range, max_value, handler, err)
737-
}
734+
PtrOutOfRange { range, max_value } => add_range_arg(range, max_value, handler, err),
738735
OutOfRange { range, max_value, value } => {
739736
err.set_arg("value", value);
740737
add_range_arg(range, max_value, handler, err);

compiler/rustc_const_eval/src/interpret/discriminant.rs

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,7 @@
22
33
use rustc_middle::ty::layout::{LayoutOf, PrimitiveExt};
44
use rustc_middle::{mir, ty};
5-
use rustc_target::abi::{self, TagEncoding};
6-
use rustc_target::abi::{VariantIdx, Variants};
5+
use rustc_target::abi::{self, TagEncoding, VariantIdx, Variants, WrappingRange};
76

87
use super::{ImmTy, InterpCx, InterpResult, Machine, OpTy, PlaceTy, Scalar};
98

@@ -180,19 +179,24 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
180179
// discriminant (encoded in niche/tag) and variant index are the same.
181180
let variants_start = niche_variants.start().as_u32();
182181
let variants_end = niche_variants.end().as_u32();
182+
let variants_len = u128::from(variants_end - variants_start);
183183
let variant = match tag_val.try_to_int() {
184184
Err(dbg_val) => {
185185
// So this is a pointer then, and casting to an int failed.
186186
// Can only happen during CTFE.
187-
// The niche must be just 0, and the ptr not null, then we know this is
188-
// okay. Everything else, we conservatively reject.
189-
let ptr_valid = niche_start == 0
190-
&& variants_start == variants_end
191-
&& !self.scalar_may_be_null(tag_val)?;
192-
if !ptr_valid {
187+
// The pointer and niches ranges must be disjoint, then we know
188+
// this is the untagged variant (as the value is not in the niche).
189+
// Everything else, we conservatively reject.
190+
let range = self.ptr_scalar_range(tag_val)?;
191+
let niches = WrappingRange {
192+
start: niche_start,
193+
end: niche_start.wrapping_add(variants_len),
194+
};
195+
if niches.overlaps_range(range) {
193196
throw_ub!(InvalidTag(dbg_val))
197+
} else {
198+
untagged_variant
194199
}
195-
untagged_variant
196200
}
197201
Ok(tag_bits) => {
198202
let tag_bits = tag_bits.assert_bits(tag_layout.size);
@@ -205,7 +209,7 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
205209
let variant_index_relative =
206210
variant_index_relative_val.to_scalar().assert_bits(tag_val.layout.size);
207211
// Check if this is in the range that indicates an actual discriminant.
208-
if variant_index_relative <= u128::from(variants_end - variants_start) {
212+
if variant_index_relative <= variants_len {
209213
let variant_index_relative = u32::try_from(variant_index_relative)
210214
.expect("we checked that this fits into a u32");
211215
// Then computing the absolute variant idx should not overflow any more.

compiler/rustc_const_eval/src/interpret/memory.rs

Lines changed: 26 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ use std::assert_matches::assert_matches;
1010
use std::borrow::Cow;
1111
use std::collections::VecDeque;
1212
use std::fmt;
13+
use std::ops::RangeInclusive;
1314
use std::ptr;
1415

1516
use rustc_ast::Mutability;
@@ -1222,24 +1223,34 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
12221223

12231224
/// Machine pointer introspection.
12241225
impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
1225-
/// Test if this value might be null.
1226+
/// Turn a pointer-sized scalar into a (non-empty) range of possible values.
12261227
/// If the machine does not support ptr-to-int casts, this is conservative.
1227-
pub fn scalar_may_be_null(&self, scalar: Scalar<M::Provenance>) -> InterpResult<'tcx, bool> {
1228-
Ok(match scalar.try_to_int() {
1229-
Ok(int) => int.is_null(),
1230-
Err(_) => {
1231-
// Can only happen during CTFE.
1232-
let ptr = scalar.to_pointer(self)?;
1233-
match self.ptr_try_get_alloc_id(ptr) {
1234-
Ok((alloc_id, offset, _)) => {
1235-
let (size, _align, _kind) = self.get_alloc_info(alloc_id);
1236-
// If the pointer is out-of-bounds, it may be null.
1237-
// Note that one-past-the-end (offset == size) is still inbounds, and never null.
1238-
offset > size
1239-
}
1240-
Err(_offset) => bug!("a non-int scalar is always a pointer"),
1228+
pub fn ptr_scalar_range(
1229+
&self,
1230+
scalar: Scalar<M::Provenance>,
1231+
) -> InterpResult<'tcx, RangeInclusive<u64>> {
1232+
if let Ok(int) = scalar.to_target_usize(self) {
1233+
return Ok(int..=int);
1234+
}
1235+
1236+
let ptr = scalar.to_pointer(self)?;
1237+
1238+
// Can only happen during CTFE.
1239+
Ok(match self.ptr_try_get_alloc_id(ptr) {
1240+
Ok((alloc_id, offset, _)) => {
1241+
let offset = offset.bytes();
1242+
let (size, align, _) = self.get_alloc_info(alloc_id);
1243+
let dl = self.data_layout();
1244+
if offset > size.bytes() {
1245+
// If the pointer is out-of-bounds, we do not have a
1246+
// meaningful range to return.
1247+
0..=dl.max_address()
1248+
} else {
1249+
let (min, max) = dl.address_range_for(size, align);
1250+
(min + offset)..=(max + offset)
12411251
}
12421252
}
1253+
Err(_offset) => bug!("a non-int scalar is always a pointer"),
12431254
})
12441255
}
12451256

compiler/rustc_const_eval/src/interpret/validity.rs

Lines changed: 21 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,7 @@ use rustc_middle::mir::interpret::{
1919
use rustc_middle::ty;
2020
use rustc_middle::ty::layout::{LayoutOf, TyAndLayout};
2121
use rustc_span::symbol::{sym, Symbol};
22-
use rustc_target::abi::{
23-
Abi, FieldIdx, Scalar as ScalarAbi, Size, VariantIdx, Variants, WrappingRange,
24-
};
22+
use rustc_target::abi::{Abi, FieldIdx, Scalar as ScalarAbi, Size, VariantIdx, Variants};
2523

2624
use std::hash::Hash;
2725

@@ -554,7 +552,7 @@ impl<'rt, 'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> ValidityVisitor<'rt, 'mir, '
554552
// FIXME: Check if the signature matches
555553
} else {
556554
// Otherwise (for standalone Miri), we have to still check it to be non-null.
557-
if self.ecx.scalar_may_be_null(value)? {
555+
if self.ecx.ptr_scalar_range(value)?.contains(&0) {
558556
throw_validation_failure!(self.path, NullFnPtr);
559557
}
560558
}
@@ -595,46 +593,36 @@ impl<'rt, 'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> ValidityVisitor<'rt, 'mir, '
595593
) -> InterpResult<'tcx> {
596594
let size = scalar_layout.size(self.ecx);
597595
let valid_range = scalar_layout.valid_range(self.ecx);
598-
let WrappingRange { start, end } = valid_range;
599596
let max_value = size.unsigned_int_max();
600-
assert!(end <= max_value);
601-
let bits = match scalar.try_to_int() {
602-
Ok(int) => int.assert_bits(size),
597+
assert!(valid_range.end <= max_value);
598+
match scalar.try_to_int() {
599+
Ok(int) => {
600+
// We have an explicit int: check it against the valid range.
601+
let bits = int.assert_bits(size);
602+
if valid_range.contains(bits) {
603+
Ok(())
604+
} else {
605+
throw_validation_failure!(
606+
self.path,
607+
OutOfRange { value: format!("{bits}"), range: valid_range, max_value }
608+
)
609+
}
610+
}
603611
Err(_) => {
604612
// So this is a pointer then, and casting to an int failed.
605613
// Can only happen during CTFE.
606-
// We support 2 kinds of ranges here: full range, and excluding zero.
607-
if start == 1 && end == max_value {
608-
// Only null is the niche. So make sure the ptr is NOT null.
609-
if self.ecx.scalar_may_be_null(scalar)? {
610-
throw_validation_failure!(
611-
self.path,
612-
NullablePtrOutOfRange { range: valid_range, max_value }
613-
)
614-
} else {
615-
return Ok(());
616-
}
617-
} else if scalar_layout.is_always_valid(self.ecx) {
618-
// Easy. (This is reachable if `enforce_number_validity` is set.)
619-
return Ok(());
614+
// We check if the possible addresses are compatible with the valid range.
615+
let range = self.ecx.ptr_scalar_range(scalar)?;
616+
if valid_range.contains_range(range) {
617+
Ok(())
620618
} else {
621-
// Conservatively, we reject, because the pointer *could* have a bad
622-
// value.
619+
// Reject conservatively, because the pointer *could* have a bad value.
623620
throw_validation_failure!(
624621
self.path,
625622
PtrOutOfRange { range: valid_range, max_value }
626623
)
627624
}
628625
}
629-
};
630-
// Now compare.
631-
if valid_range.contains(bits) {
632-
Ok(())
633-
} else {
634-
throw_validation_failure!(
635-
self.path,
636-
OutOfRange { value: format!("{bits}"), range: valid_range, max_value }
637-
)
638626
}
639627
}
640628
}

compiler/rustc_middle/src/mir/interpret/error.rs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -388,7 +388,6 @@ pub enum ValidationErrorKind<'tcx> {
388388
MutableRefInConst,
389389
NullFnPtr,
390390
NeverVal,
391-
NullablePtrOutOfRange { range: WrappingRange, max_value: u128 },
392391
PtrOutOfRange { range: WrappingRange, max_value: u128 },
393392
OutOfRange { value: String, range: WrappingRange, max_value: u128 },
394393
UnsafeCell,

0 commit comments

Comments
 (0)