Skip to content

Commit 5f87957

Browse files
AdamMagierFOSSAdam Magier
andauthored
[clang][CodeGen][UBSan] Fixing shift-exponent generation for _BitInt (#80515)
Testing the shift-exponent check with small width _BitInt values exposed a bug in ScalarExprEmitter::GetWidthMinusOneValue when using the result to determine valid exponent sizes. False positives were reported for some left shifts when width(LHS)-1 > range(RHS) and false negatives were reported for right shifts when value(RHS) > range(LHS). This patch caps the maximum value of GetWidthMinusOneValue to fit within range(RHS) to fix the issue with left shifts and fixes a code generation in EmitShr to fix the issue with right shifts and renames the function to GetMaximumShiftAmount to better reflect the new behaviour. Fixes #80135. Co-authored-by: Adam Magier <[email protected]>
1 parent 6812bc4 commit 5f87957

File tree

2 files changed

+51
-7
lines changed

2 files changed

+51
-7
lines changed

clang/lib/CodeGen/CGExprScalar.cpp

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -774,7 +774,7 @@ class ScalarExprEmitter
774774
void EmitUndefinedBehaviorIntegerDivAndRemCheck(const BinOpInfo &Ops,
775775
llvm::Value *Zero,bool isDiv);
776776
// Common helper for getting how wide LHS of shift is.
777-
static Value *GetWidthMinusOneValue(Value* LHS,Value* RHS);
777+
static Value *GetMaximumShiftAmount(Value *LHS, Value *RHS);
778778

779779
// Used for shifting constraints for OpenCL, do mask for powers of 2, URem for
780780
// non powers of two.
@@ -4115,13 +4115,21 @@ Value *ScalarExprEmitter::EmitSub(const BinOpInfo &op) {
41154115
return Builder.CreateExactSDiv(diffInChars, divisor, "sub.ptr.div");
41164116
}
41174117

4118-
Value *ScalarExprEmitter::GetWidthMinusOneValue(Value* LHS,Value* RHS) {
4118+
Value *ScalarExprEmitter::GetMaximumShiftAmount(Value *LHS, Value *RHS) {
41194119
llvm::IntegerType *Ty;
41204120
if (llvm::VectorType *VT = dyn_cast<llvm::VectorType>(LHS->getType()))
41214121
Ty = cast<llvm::IntegerType>(VT->getElementType());
41224122
else
41234123
Ty = cast<llvm::IntegerType>(LHS->getType());
4124-
return llvm::ConstantInt::get(RHS->getType(), Ty->getBitWidth() - 1);
4124+
// For a given type of LHS the maximum shift amount is width(LHS)-1, however
4125+
// it can occur that width(LHS)-1 > range(RHS). Since there is no check for
4126+
// this in ConstantInt::get, this results in the value getting truncated.
4127+
// Constrain the return value to be max(RHS) in this case.
4128+
llvm::Type *RHSTy = RHS->getType();
4129+
llvm::APInt RHSMax = llvm::APInt::getMaxValue(RHSTy->getScalarSizeInBits());
4130+
if (RHSMax.ult(Ty->getBitWidth()))
4131+
return llvm::ConstantInt::get(RHSTy, RHSMax);
4132+
return llvm::ConstantInt::get(RHSTy, Ty->getBitWidth() - 1);
41254133
}
41264134

41274135
Value *ScalarExprEmitter::ConstrainShiftValue(Value *LHS, Value *RHS,
@@ -4133,7 +4141,7 @@ Value *ScalarExprEmitter::ConstrainShiftValue(Value *LHS, Value *RHS,
41334141
Ty = cast<llvm::IntegerType>(LHS->getType());
41344142

41354143
if (llvm::isPowerOf2_64(Ty->getBitWidth()))
4136-
return Builder.CreateAnd(RHS, GetWidthMinusOneValue(LHS, RHS), Name);
4144+
return Builder.CreateAnd(RHS, GetMaximumShiftAmount(LHS, RHS), Name);
41374145

41384146
return Builder.CreateURem(
41394147
RHS, llvm::ConstantInt::get(RHS->getType(), Ty->getBitWidth()), Name);
@@ -4166,7 +4174,7 @@ Value *ScalarExprEmitter::EmitShl(const BinOpInfo &Ops) {
41664174
isa<llvm::IntegerType>(Ops.LHS->getType())) {
41674175
CodeGenFunction::SanitizerScope SanScope(&CGF);
41684176
SmallVector<std::pair<Value *, SanitizerMask>, 2> Checks;
4169-
llvm::Value *WidthMinusOne = GetWidthMinusOneValue(Ops.LHS, Ops.RHS);
4177+
llvm::Value *WidthMinusOne = GetMaximumShiftAmount(Ops.LHS, Ops.RHS);
41704178
llvm::Value *ValidExponent = Builder.CreateICmpULE(Ops.RHS, WidthMinusOne);
41714179

41724180
if (SanitizeExponent) {
@@ -4184,7 +4192,7 @@ Value *ScalarExprEmitter::EmitShl(const BinOpInfo &Ops) {
41844192
Builder.CreateCondBr(ValidExponent, CheckShiftBase, Cont);
41854193
llvm::Value *PromotedWidthMinusOne =
41864194
(RHS == Ops.RHS) ? WidthMinusOne
4187-
: GetWidthMinusOneValue(Ops.LHS, RHS);
4195+
: GetMaximumShiftAmount(Ops.LHS, RHS);
41884196
CGF.EmitBlock(CheckShiftBase);
41894197
llvm::Value *BitsShiftedOff = Builder.CreateLShr(
41904198
Ops.LHS, Builder.CreateSub(PromotedWidthMinusOne, RHS, "shl.zeros",
@@ -4235,7 +4243,7 @@ Value *ScalarExprEmitter::EmitShr(const BinOpInfo &Ops) {
42354243
isa<llvm::IntegerType>(Ops.LHS->getType())) {
42364244
CodeGenFunction::SanitizerScope SanScope(&CGF);
42374245
llvm::Value *Valid =
4238-
Builder.CreateICmpULE(RHS, GetWidthMinusOneValue(Ops.LHS, RHS));
4246+
Builder.CreateICmpULE(Ops.RHS, GetMaximumShiftAmount(Ops.LHS, Ops.RHS));
42394247
EmitBinOpCheck(std::make_pair(Valid, SanitizerKind::ShiftExponent), Ops);
42404248
}
42414249

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
// RUN: %clang_cc1 %s -O0 -fsanitize=shift-exponent -emit-llvm -std=c2x -triple=x86_64-unknown-linux -o - | FileCheck %s
2+
3+
// Checking that the code generation is using the unextended/untruncated
4+
// exponent values and capping the values accordingly
5+
6+
// CHECK-LABEL: define{{.*}} i32 @test_left_variable
7+
int test_left_variable(unsigned _BitInt(5) b, unsigned _BitInt(2) e) {
8+
// CHECK: [[E_REG:%.+]] = load [[E_SIZE:i2]]
9+
// CHECK: icmp ule [[E_SIZE]] [[E_REG]], -1
10+
return b << e;
11+
}
12+
13+
// CHECK-LABEL: define{{.*}} i32 @test_right_variable
14+
int test_right_variable(unsigned _BitInt(2) b, unsigned _BitInt(3) e) {
15+
// CHECK: [[E_REG:%.+]] = load [[E_SIZE:i3]]
16+
// CHECK: icmp ule [[E_SIZE]] [[E_REG]], 1
17+
return b >> e;
18+
}
19+
20+
// Old code generation would give false positives on left shifts when:
21+
// value(e) > (width(b) - 1 % 2 ** width(e))
22+
// CHECK-LABEL: define{{.*}} i32 @test_left_literal
23+
int test_left_literal(unsigned _BitInt(5) b) {
24+
// CHECK-NOT: br i1 false, label %cont, label %handler.shift_out_of_bounds
25+
// CHECK: br i1 true, label %cont, label %handler.shift_out_of_bounds
26+
return b << 3uwb;
27+
}
28+
29+
// Old code generation would give false positives on right shifts when:
30+
// (value(e) % 2 ** width(b)) < width(b)
31+
// CHECK-LABEL: define{{.*}} i32 @test_right_literal
32+
int test_right_literal(unsigned _BitInt(2) b) {
33+
// CHECK-NOT: br i1 true, label %cont, label %handler.shift_out_of_bounds
34+
// CHECK: br i1 false, label %cont, label %handler.shift_out_of_bounds
35+
return b >> 4uwb;
36+
}

0 commit comments

Comments
 (0)