Skip to content

Commit 849b4e6

Browse files
Fix how we handle niche optimization with ZST (rust-lang#1205)
* Fix how we handle niche optimization with ZST The compiler uses niche optimization when there is only one variant that contains nonzero sized fields. This variant is represented as the dataful variant. However, other variants may have one or more ZST fields that can be deconstructed and referenced in the code. Before this change, we were not generating code for them and it generate type mismatch issues when the user tried to access them. This change ensures that we represent all variants that have any field, so their access is valid. * Enable debug assert inside projection mismatch We should probably turn this into an assert and revert the logic added by model-checking/kani#1057. But this is a much larger change that I prefer creating a separate PR for. Co-authored-by: Zyad Hassan <[email protected]>
1 parent c7c0c4f commit 849b4e6

File tree

7 files changed

+178
-105
lines changed

7 files changed

+178
-105
lines changed

cprover_bindings/src/goto_program/stmt.rs

Lines changed: 7 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@ use self::StmtBody::*;
44
use super::{BuiltinFn, Expr, Location};
55
use crate::{InternString, InternedString};
66
use std::fmt::Debug;
7-
use tracing::debug;
87

98
///////////////////////////////////////////////////////////////////////////////////////////////
109
/// Datatypes
@@ -166,28 +165,13 @@ macro_rules! stmt {
166165
impl Stmt {
167166
/// `lhs = rhs;`
168167
pub fn assign(lhs: Expr, rhs: Expr, loc: Location) -> Self {
169-
//Temporarily work around https://github.com/model-checking/kani/issues/95
170-
//by disabling the assert and soundly assigning nondet
171-
//assert_eq!(lhs.typ(), rhs.typ());
172-
if lhs.typ() != rhs.typ() {
173-
debug!(
174-
"WARNING: assign statement with unequal types lhs {:?} rhs {:?}",
175-
lhs.typ(),
176-
rhs.typ()
177-
);
178-
let assert_stmt = Stmt::assert_false(
179-
"sanity_check",
180-
&format!(
181-
"Reached assignment statement with unequal types {:?} {:?}",
182-
lhs.typ(),
183-
rhs.typ()
184-
),
185-
loc.clone(),
186-
);
187-
let nondet_value = lhs.typ().nondet();
188-
let nondet_assign_stmt = stmt!(Assign { lhs, rhs: nondet_value }, loc.clone());
189-
return Stmt::block(vec![assert_stmt, nondet_assign_stmt], loc);
190-
}
168+
assert_eq!(
169+
lhs.typ(),
170+
rhs.typ(),
171+
"Error: assign statement with unequal types lhs {:?} rhs {:?}",
172+
lhs.typ(),
173+
rhs.typ()
174+
);
191175
stmt!(Assign { lhs, rhs }, loc)
192176
}
193177

kani-compiler/src/codegen_cprover_gotoc/codegen/place.rs

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -169,10 +169,12 @@ impl<'tcx> ProjectedPlace<'tcx> {
169169
if let Some((expr_ty, ty_from_mir)) =
170170
Self::check_expr_typ_mismatch(&goto_expr, &mir_typ_or_variant, ctx)
171171
{
172-
warn!(
172+
let msg = format!(
173173
"Unexpected type mismatch in projection:\n{:?}\nExpr type\n{:?}\nType from MIR\n{:?}",
174174
goto_expr, expr_ty, ty_from_mir
175175
);
176+
warn!("{}", msg);
177+
debug_assert!(false, "{}", msg);
176178
return Err(UnimplementedData::new(
177179
"Projection mismatch",
178180
"https://github.com/model-checking/kani/issues/277",
@@ -499,18 +501,18 @@ impl<'tcx> GotocCtx<'tcx> {
499501
match t.kind() {
500502
ty::Adt(def, _) => {
501503
let variant = def.variants().get(idx).unwrap();
504+
let case_name = variant.name.to_string();
502505
let typ = TypeOrVariant::Variant(variant);
503506
let expr = match &self.layout_of(t).variants {
504507
Variants::Single { .. } => before.goto_expr,
505508
Variants::Multiple { tag_encoding, .. } => match tag_encoding {
506-
TagEncoding::Direct => {
507-
let case_name = variant.name.to_string();
508-
before
509-
.goto_expr
510-
.member("cases", &self.symbol_table)
511-
.member(&case_name, &self.symbol_table)
509+
TagEncoding::Direct => before
510+
.goto_expr
511+
.member("cases", &self.symbol_table)
512+
.member(&case_name, &self.symbol_table),
513+
TagEncoding::Niche { .. } => {
514+
before.goto_expr.member(&case_name, &self.symbol_table)
512515
}
513-
TagEncoding::Niche { .. } => before.goto_expr,
514516
},
515517
};
516518
ProjectedPlace::try_new(

kani-compiler/src/codegen_cprover_gotoc/codegen/typ.rs

Lines changed: 65 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1084,7 +1084,17 @@ impl<'tcx> GotocCtx<'tcx> {
10841084
Variants::Multiple { tag_encoding, variants, .. } => {
10851085
match tag_encoding {
10861086
TagEncoding::Direct => {
1087-
// direct encoding of tags
1087+
// For direct encoding of tags, we generate a type with two fields:
1088+
// ```
1089+
// struct tag-<> { // enum type
1090+
// case: <discriminant type>,
1091+
// cases: tag-<>-union,
1092+
// }
1093+
// ```
1094+
// The `case` field type determined by the enum representation
1095+
// (`#[repr]`) and it represents which variant is being used.
1096+
// The `cases` field is a union of all variant types where the name
1097+
// of each union field is the name of the corresponding discriminant.
10881098
let discr_t = ctx.codegen_enum_discr_typ(ty);
10891099
let int = ctx.codegen_ty(discr_t);
10901100
let discr_offset = ctx.layout_of(discr_t).size.bits_usize();
@@ -1098,31 +1108,48 @@ impl<'tcx> GotocCtx<'tcx> {
10981108
}
10991109
fields.push(Type::datatype_component(
11001110
"cases",
1101-
ctx.codegen_enum_cases_union(
1102-
name,
1103-
adtdef,
1104-
subst,
1105-
variants,
1106-
initial_offset,
1111+
ctx.ensure_union(
1112+
&format!("{}-union", name.to_string()),
1113+
NO_PRETTY_NAME,
1114+
|ctx, name| {
1115+
ctx.codegen_enum_cases(
1116+
name,
1117+
adtdef,
1118+
subst,
1119+
variants,
1120+
initial_offset,
1121+
)
1122+
},
11071123
),
11081124
));
11091125
fields
11101126
}
11111127
TagEncoding::Niche { dataful_variant, .. } => {
1112-
// niche encoding is an optimization, which uses invalid values for discriminant
1113-
// for example, Option<&i32> becomes just a pointer to i32, and pattern
1114-
// matching becomes checking whether the pointer is null or not. direct
1115-
// encoding, on the other hand, would have been maintaining a field
1116-
// storing the discriminant, which is a few bytes larger.
1128+
// Enumerations with multiple variants and niche encoding have a
1129+
// specific format that can be used to optimize its layout and reduce
1130+
// memory consumption.
11171131
//
1118-
// dataful_variant is pretty much the only variant which contains the valid data
1119-
let variant = &adtdef.variants()[*dataful_variant];
1120-
ctx.codegen_variant_struct_fields(
1121-
variant,
1122-
subst,
1123-
&variants[*dataful_variant],
1124-
0,
1125-
)
1132+
// These enumerations have one and only one variant with non-ZST
1133+
// fields which is referred to by the `dataful_variant` index. Their
1134+
// final size and alignment is equal to the one from the
1135+
// `dataful_variant`. All other variants either don't have any field
1136+
// or all fields types are ZST.
1137+
//
1138+
// Because of that, we can represent these enums as simple structures
1139+
// where each field represent one variant. This allows them to be
1140+
// referred to correctly.
1141+
//
1142+
// Note: I tried using a union instead but it had a significant runtime
1143+
// penalty.
1144+
tracing::trace!(
1145+
?name,
1146+
?variants,
1147+
?dataful_variant,
1148+
?tag_encoding,
1149+
?subst,
1150+
"codegen_enum: Niche"
1151+
);
1152+
ctx.codegen_enum_cases(name, adtdef, subst, variants, 0)
11261153
}
11271154
}
11281155
}
@@ -1211,32 +1238,37 @@ impl<'tcx> GotocCtx<'tcx> {
12111238
}
12121239
}
12131240

1214-
fn codegen_enum_cases_union(
1241+
/// Codegen the type for each variant represented in this enum.
1242+
/// As an optimization, we ignore the ones that don't have any field, since they
1243+
/// are only manipulated via discriminant operations.
1244+
fn codegen_enum_cases(
12151245
&mut self,
12161246
name: InternedString,
12171247
def: &'tcx AdtDef,
12181248
subst: &'tcx InternalSubsts<'tcx>,
12191249
layouts: &IndexVec<VariantIdx, Layout>,
12201250
initial_offset: usize,
1221-
) -> Type {
1222-
// TODO Should we have a pretty name here?
1223-
self.ensure_union(&format!("{}-union", name.to_string()), NO_PRETTY_NAME, |ctx, name| {
1224-
def.variants()
1225-
.iter_enumerated()
1226-
.map(|(i, case)| {
1227-
Type::datatype_component(
1251+
) -> Vec<DatatypeComponent> {
1252+
def.variants()
1253+
.iter_enumerated()
1254+
.filter_map(|(i, case)| {
1255+
if case.fields.is_empty() {
1256+
// Skip variant types that cannot be referenced.
1257+
None
1258+
} else {
1259+
Some(Type::datatype_component(
12281260
&case.name.to_string(),
1229-
ctx.codegen_enum_case_struct(
1261+
self.codegen_enum_case_struct(
12301262
name,
12311263
case,
12321264
subst,
12331265
&layouts[i],
12341266
initial_offset,
12351267
),
1236-
)
1237-
})
1238-
.collect()
1239-
})
1268+
))
1269+
}
1270+
})
1271+
.collect()
12401272
}
12411273

12421274
fn codegen_enum_case_struct(
Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
// Copyright Kani Contributors
2+
// SPDX-License-Identifier: Apache-2.0 OR MIT
3+
4+
//! Testcase for niche encoding where there are multiple variants but only one contains
5+
//! non-zero sized data with niche, making it a great candidate for niche optimization.
6+
#[derive(PartialEq)]
7+
enum MyEnum {
8+
NoFields,
9+
DataFul(bool),
10+
UnitFields((), ()),
11+
ZSTField(ZeroSized),
12+
ZSTStruct { field: ZeroSized, unit: () },
13+
}
14+
15+
#[derive(PartialEq)]
16+
struct ZeroSized {}
17+
18+
impl ZeroSized {
19+
fn works(&self) -> bool {
20+
true
21+
}
22+
}
23+
24+
impl MyEnum {
25+
fn create_no_field() -> MyEnum {
26+
MyEnum::NoFields
27+
}
28+
29+
fn create_data_ful(data: bool) -> MyEnum {
30+
MyEnum::DataFul(data)
31+
}
32+
33+
fn create_unit() -> MyEnum {
34+
MyEnum::UnitFields((), ())
35+
}
36+
37+
fn create_zst_field() -> MyEnum {
38+
MyEnum::ZSTField(ZeroSized {})
39+
}
40+
41+
fn create_zst_struct() -> MyEnum {
42+
MyEnum::ZSTStruct { field: ZeroSized {}, unit: () }
43+
}
44+
}
45+
46+
/// Ensure we are testing a case of niche optimization.
47+
#[kani::proof]
48+
fn check_is_niche() {
49+
assert_eq!(std::mem::size_of::<MyEnum>(), 1);
50+
assert_eq!(std::mem::size_of::<bool>(), 1);
51+
}
52+
53+
/// Check the behavior for the variant without any field.
54+
#[kani::proof]
55+
fn check_niche_no_fields() {
56+
let x = MyEnum::create_no_field();
57+
assert!(matches!(x, MyEnum::NoFields));
58+
}
59+
60+
/// Check the behavior for the dataful variant.
61+
#[kani::proof]
62+
fn check_niche_data_ful() {
63+
let x = MyEnum::create_data_ful(true);
64+
assert!(matches!(x, MyEnum::DataFul(true)));
65+
}
66+
67+
/// Check the behavior for the variant with multiple unit fields.
68+
#[kani::proof]
69+
fn check_niche_unit_fields() {
70+
let x = MyEnum::create_unit();
71+
assert_eq!(x, MyEnum::UnitFields((), ()));
72+
if let MyEnum::UnitFields(ref v, ..) = &x {
73+
assert_eq!(std::mem::size_of_val(v), 0);
74+
}
75+
}
76+
77+
/// Check the behavior for the variant with one ZST field.
78+
#[kani::proof]
79+
fn check_niche_zst_field() {
80+
let x = MyEnum::create_zst_field();
81+
assert_eq!(x, MyEnum::ZSTField(ZeroSized {}));
82+
if let MyEnum::ZSTField(ref field) = &x {
83+
assert!(field.works());
84+
}
85+
}
86+
87+
/// Check the behavior for the variant representing a struct with one ZST field.
88+
#[kani::proof]
89+
fn check_niche_zst_struct() {
90+
let x = MyEnum::create_zst_struct();
91+
assert!(matches!(x, MyEnum::ZSTStruct { .. }));
92+
if let MyEnum::ZSTStruct { ref field, ref unit } = &x {
93+
assert_eq!(std::mem::size_of_val(unit), 0);
94+
assert!(field.works());
95+
}
96+
}

tests/ui/filter-sanity-checks/expected

Lines changed: 0 additions & 3 deletions
This file was deleted.

tests/ui/filter-sanity-checks/sanity_check_fail.rs

Lines changed: 0 additions & 38 deletions
This file was deleted.

0 commit comments

Comments
 (0)