Skip to content

Commit 20667db

Browse files
authored
[clang][CodeGen] Fix shift-exponent ubsan check for signed _BitInt (#88004)
Commit 5f87957 (pull-requst #80515) corrected some codegen problems related to _BitInt types being used as shift exponents. But it did not fix it properly for the special case when the shift count operand is a signed _BitInt. The basic problem is the same as the one solved for unsigned _BitInt. As we use an unsigned comparison to see if the shift exponent is out-of-bounds, then we need to find an unsigned maximum allowed shift amount to use in the check. Normally the shift amount is limited by bitwidth of the LHS of the shift. However, when the RHS type is small in relation to the LHS then we need to use a value that fits inside the bitwidth of the RHS instead. The earlier fix simply used the unsigned maximum when deterining the max shift amount based on the RHS type. It did however not take into consideration that the RHS type could have a signed representation. In such situations we need to use the signed maximum instead. Otherwise we do not recognize a negative shift exponent as UB.
1 parent bd898d5 commit 20667db

File tree

2 files changed

+54
-10
lines changed

2 files changed

+54
-10
lines changed

clang/lib/CodeGen/CGExprScalar.cpp

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,15 @@ struct BinOpInfo {
147147
return UnOp->getSubExpr()->getType()->isFixedPointType();
148148
return false;
149149
}
150+
151+
/// Check if the RHS has a signed integer representation.
152+
bool rhsHasSignedIntegerRepresentation() const {
153+
if (const auto *BinOp = dyn_cast<BinaryOperator>(E)) {
154+
QualType RHSType = BinOp->getRHS()->getType();
155+
return RHSType->hasSignedIntegerRepresentation();
156+
}
157+
return false;
158+
}
150159
};
151160

152161
static bool MustVisitNullValue(const Expr *E) {
@@ -782,7 +791,7 @@ class ScalarExprEmitter
782791
void EmitUndefinedBehaviorIntegerDivAndRemCheck(const BinOpInfo &Ops,
783792
llvm::Value *Zero,bool isDiv);
784793
// Common helper for getting how wide LHS of shift is.
785-
static Value *GetMaximumShiftAmount(Value *LHS, Value *RHS);
794+
static Value *GetMaximumShiftAmount(Value *LHS, Value *RHS, bool RHSIsSigned);
786795

787796
// Used for shifting constraints for OpenCL, do mask for powers of 2, URem for
788797
// non powers of two.
@@ -4344,7 +4353,8 @@ Value *ScalarExprEmitter::EmitSub(const BinOpInfo &op) {
43444353
return Builder.CreateExactSDiv(diffInChars, divisor, "sub.ptr.div");
43454354
}
43464355

4347-
Value *ScalarExprEmitter::GetMaximumShiftAmount(Value *LHS, Value *RHS) {
4356+
Value *ScalarExprEmitter::GetMaximumShiftAmount(Value *LHS, Value *RHS,
4357+
bool RHSIsSigned) {
43484358
llvm::IntegerType *Ty;
43494359
if (llvm::VectorType *VT = dyn_cast<llvm::VectorType>(LHS->getType()))
43504360
Ty = cast<llvm::IntegerType>(VT->getElementType());
@@ -4355,7 +4365,9 @@ Value *ScalarExprEmitter::GetMaximumShiftAmount(Value *LHS, Value *RHS) {
43554365
// this in ConstantInt::get, this results in the value getting truncated.
43564366
// Constrain the return value to be max(RHS) in this case.
43574367
llvm::Type *RHSTy = RHS->getType();
4358-
llvm::APInt RHSMax = llvm::APInt::getMaxValue(RHSTy->getScalarSizeInBits());
4368+
llvm::APInt RHSMax =
4369+
RHSIsSigned ? llvm::APInt::getSignedMaxValue(RHSTy->getScalarSizeInBits())
4370+
: llvm::APInt::getMaxValue(RHSTy->getScalarSizeInBits());
43594371
if (RHSMax.ult(Ty->getBitWidth()))
43604372
return llvm::ConstantInt::get(RHSTy, RHSMax);
43614373
return llvm::ConstantInt::get(RHSTy, Ty->getBitWidth() - 1);
@@ -4370,7 +4382,7 @@ Value *ScalarExprEmitter::ConstrainShiftValue(Value *LHS, Value *RHS,
43704382
Ty = cast<llvm::IntegerType>(LHS->getType());
43714383

43724384
if (llvm::isPowerOf2_64(Ty->getBitWidth()))
4373-
return Builder.CreateAnd(RHS, GetMaximumShiftAmount(LHS, RHS), Name);
4385+
return Builder.CreateAnd(RHS, GetMaximumShiftAmount(LHS, RHS, false), Name);
43744386

43754387
return Builder.CreateURem(
43764388
RHS, llvm::ConstantInt::get(RHS->getType(), Ty->getBitWidth()), Name);
@@ -4403,7 +4415,9 @@ Value *ScalarExprEmitter::EmitShl(const BinOpInfo &Ops) {
44034415
isa<llvm::IntegerType>(Ops.LHS->getType())) {
44044416
CodeGenFunction::SanitizerScope SanScope(&CGF);
44054417
SmallVector<std::pair<Value *, SanitizerMask>, 2> Checks;
4406-
llvm::Value *WidthMinusOne = GetMaximumShiftAmount(Ops.LHS, Ops.RHS);
4418+
bool RHSIsSigned = Ops.rhsHasSignedIntegerRepresentation();
4419+
llvm::Value *WidthMinusOne =
4420+
GetMaximumShiftAmount(Ops.LHS, Ops.RHS, RHSIsSigned);
44074421
llvm::Value *ValidExponent = Builder.CreateICmpULE(Ops.RHS, WidthMinusOne);
44084422

44094423
if (SanitizeExponent) {
@@ -4421,7 +4435,7 @@ Value *ScalarExprEmitter::EmitShl(const BinOpInfo &Ops) {
44214435
Builder.CreateCondBr(ValidExponent, CheckShiftBase, Cont);
44224436
llvm::Value *PromotedWidthMinusOne =
44234437
(RHS == Ops.RHS) ? WidthMinusOne
4424-
: GetMaximumShiftAmount(Ops.LHS, RHS);
4438+
: GetMaximumShiftAmount(Ops.LHS, RHS, RHSIsSigned);
44254439
CGF.EmitBlock(CheckShiftBase);
44264440
llvm::Value *BitsShiftedOff = Builder.CreateLShr(
44274441
Ops.LHS, Builder.CreateSub(PromotedWidthMinusOne, RHS, "shl.zeros",
@@ -4471,8 +4485,9 @@ Value *ScalarExprEmitter::EmitShr(const BinOpInfo &Ops) {
44714485
else if (CGF.SanOpts.has(SanitizerKind::ShiftExponent) &&
44724486
isa<llvm::IntegerType>(Ops.LHS->getType())) {
44734487
CodeGenFunction::SanitizerScope SanScope(&CGF);
4474-
llvm::Value *Valid =
4475-
Builder.CreateICmpULE(Ops.RHS, GetMaximumShiftAmount(Ops.LHS, Ops.RHS));
4488+
bool RHSIsSigned = Ops.rhsHasSignedIntegerRepresentation();
4489+
llvm::Value *Valid = Builder.CreateICmpULE(
4490+
Ops.RHS, GetMaximumShiftAmount(Ops.LHS, Ops.RHS, RHSIsSigned));
44764491
EmitBinOpCheck(std::make_pair(Valid, SanitizerKind::ShiftExponent), Ops);
44774492
}
44784493

clang/test/CodeGen/ubsan-shift-bitint.c

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,14 @@
66
// CHECK-LABEL: define{{.*}} i32 @test_left_variable
77
int test_left_variable(unsigned _BitInt(5) b, unsigned _BitInt(2) e) {
88
// CHECK: [[E_REG:%.+]] = load [[E_SIZE:i2]]
9-
// CHECK: icmp ule [[E_SIZE]] [[E_REG]], -1
9+
// CHECK: icmp ule [[E_SIZE]] [[E_REG]], -1,
1010
return b << e;
1111
}
1212

1313
// CHECK-LABEL: define{{.*}} i32 @test_right_variable
1414
int test_right_variable(unsigned _BitInt(2) b, unsigned _BitInt(3) e) {
1515
// CHECK: [[E_REG:%.+]] = load [[E_SIZE:i3]]
16-
// CHECK: icmp ule [[E_SIZE]] [[E_REG]], 1
16+
// CHECK: icmp ule [[E_SIZE]] [[E_REG]], 1,
1717
return b >> e;
1818
}
1919

@@ -34,3 +34,32 @@ int test_right_literal(unsigned _BitInt(2) b) {
3434
// CHECK: br i1 false, label %cont, label %handler.shift_out_of_bounds
3535
return b >> 4uwb;
3636
}
37+
38+
// CHECK-LABEL: define{{.*}} i32 @test_signed_left_variable
39+
int test_signed_left_variable(unsigned _BitInt(15) b, _BitInt(2) e) {
40+
// CHECK: [[E_REG:%.+]] = load [[E_SIZE:i2]]
41+
// CHECK: icmp ule [[E_SIZE]] [[E_REG]], 1,
42+
return b << e;
43+
}
44+
45+
// CHECK-LABEL: define{{.*}} i32 @test_signed_right_variable
46+
int test_signed_right_variable(unsigned _BitInt(32) b, _BitInt(4) e) {
47+
// CHECK: [[E_REG:%.+]] = load [[E_SIZE:i4]]
48+
// CHECK: icmp ule [[E_SIZE]] [[E_REG]], 7,
49+
return b >> e;
50+
}
51+
52+
// CHECK-LABEL: define{{.*}} i32 @test_signed_left_literal
53+
int test_signed_left_literal(unsigned _BitInt(16) b) {
54+
// CHECK-NOT: br i1 true, label %cont, label %handler.shift_out_of_bounds
55+
// CHECK: br i1 false, label %cont, label %handler.shift_out_of_bounds
56+
return b << (_BitInt(4))-2wb;
57+
}
58+
59+
// CHECK-LABEL: define{{.*}} i32 @test_signed_right_literal
60+
int test_signed_right_literal(unsigned _BitInt(16) b) {
61+
// CHECK-NOT: br i1 true, label %cont, label %handler.shift_out_of_bounds
62+
// CHECK: br i1 false, label %cont, label %handler.shift_out_of_bounds
63+
return b >> (_BitInt(4))-8wb;
64+
}
65+

0 commit comments

Comments
 (0)