Skip to content

Commit df9c120

Browse files
committed
[clang][CodeGen] Fix shift-exponent ubsan check for signed _BitInt
Commit 5f87957 (pull-requst llvm#80515) corrected some codegen problems related to _BitInt types being used as shift exponents. But it did not fix it properly for the special case when the shift count operand is a signed _BitInt. The basic problem is the same as the one solved for unsigned _BitInt. As we use an unsigned comparison to see if the shift exponent is out-of-bounds, then we need to find an unsigned maximum allowed shift amount to use in the check. Normally the shift amount is limited by bitwidth of the LHS of the shift. However, when the RHS type is small in relation to the LHS then we need to use a value that fits inside the bitwidth of the RHS instead. The earlier fix simply used the unsigned maximum when deterining the max shift amount based on the RHS type. It did however not take into consideration that the RHS type could have a signed representation. In such situations we need to use the signed maximum instead. Otherwise we do not recognize a negative shift exponent as UB.
1 parent 8ccf1c1 commit df9c120

File tree

2 files changed

+54
-10
lines changed

2 files changed

+54
-10
lines changed

clang/lib/CodeGen/CGExprScalar.cpp

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,15 @@ struct BinOpInfo {
146146
return UnOp->getSubExpr()->getType()->isFixedPointType();
147147
return false;
148148
}
149+
150+
/// Check if the RHS has a signed integer representation.
151+
bool rhsHasSignedIntegerRepresentation() const {
152+
if (const auto *BinOp = dyn_cast<BinaryOperator>(E)) {
153+
QualType RHSType = BinOp->getRHS()->getType();
154+
return RHSType->hasSignedIntegerRepresentation();
155+
}
156+
return false;
157+
}
149158
};
150159

151160
static bool MustVisitNullValue(const Expr *E) {
@@ -780,7 +789,7 @@ class ScalarExprEmitter
780789
void EmitUndefinedBehaviorIntegerDivAndRemCheck(const BinOpInfo &Ops,
781790
llvm::Value *Zero,bool isDiv);
782791
// Common helper for getting how wide LHS of shift is.
783-
static Value *GetMaximumShiftAmount(Value *LHS, Value *RHS);
792+
static Value *GetMaximumShiftAmount(Value *LHS, Value *RHS, bool RHSIsSigned);
784793

785794
// Used for shifting constraints for OpenCL, do mask for powers of 2, URem for
786795
// non powers of two.
@@ -4181,7 +4190,8 @@ Value *ScalarExprEmitter::EmitSub(const BinOpInfo &op) {
41814190
return Builder.CreateExactSDiv(diffInChars, divisor, "sub.ptr.div");
41824191
}
41834192

4184-
Value *ScalarExprEmitter::GetMaximumShiftAmount(Value *LHS, Value *RHS) {
4193+
Value *ScalarExprEmitter::GetMaximumShiftAmount(Value *LHS, Value *RHS,
4194+
bool RHSIsSigned) {
41854195
llvm::IntegerType *Ty;
41864196
if (llvm::VectorType *VT = dyn_cast<llvm::VectorType>(LHS->getType()))
41874197
Ty = cast<llvm::IntegerType>(VT->getElementType());
@@ -4192,7 +4202,9 @@ Value *ScalarExprEmitter::GetMaximumShiftAmount(Value *LHS, Value *RHS) {
41924202
// this in ConstantInt::get, this results in the value getting truncated.
41934203
// Constrain the return value to be max(RHS) in this case.
41944204
llvm::Type *RHSTy = RHS->getType();
4195-
llvm::APInt RHSMax = llvm::APInt::getMaxValue(RHSTy->getScalarSizeInBits());
4205+
llvm::APInt RHSMax =
4206+
RHSIsSigned ? llvm::APInt::getSignedMaxValue(RHSTy->getScalarSizeInBits())
4207+
: llvm::APInt::getMaxValue(RHSTy->getScalarSizeInBits());
41964208
if (RHSMax.ult(Ty->getBitWidth()))
41974209
return llvm::ConstantInt::get(RHSTy, RHSMax);
41984210
return llvm::ConstantInt::get(RHSTy, Ty->getBitWidth() - 1);
@@ -4207,7 +4219,7 @@ Value *ScalarExprEmitter::ConstrainShiftValue(Value *LHS, Value *RHS,
42074219
Ty = cast<llvm::IntegerType>(LHS->getType());
42084220

42094221
if (llvm::isPowerOf2_64(Ty->getBitWidth()))
4210-
return Builder.CreateAnd(RHS, GetMaximumShiftAmount(LHS, RHS), Name);
4222+
return Builder.CreateAnd(RHS, GetMaximumShiftAmount(LHS, RHS, false), Name);
42114223

42124224
return Builder.CreateURem(
42134225
RHS, llvm::ConstantInt::get(RHS->getType(), Ty->getBitWidth()), Name);
@@ -4240,7 +4252,9 @@ Value *ScalarExprEmitter::EmitShl(const BinOpInfo &Ops) {
42404252
isa<llvm::IntegerType>(Ops.LHS->getType())) {
42414253
CodeGenFunction::SanitizerScope SanScope(&CGF);
42424254
SmallVector<std::pair<Value *, SanitizerMask>, 2> Checks;
4243-
llvm::Value *WidthMinusOne = GetMaximumShiftAmount(Ops.LHS, Ops.RHS);
4255+
bool RHSIsSigned = Ops.rhsHasSignedIntegerRepresentation();
4256+
llvm::Value *WidthMinusOne =
4257+
GetMaximumShiftAmount(Ops.LHS, Ops.RHS, RHSIsSigned);
42444258
llvm::Value *ValidExponent = Builder.CreateICmpULE(Ops.RHS, WidthMinusOne);
42454259

42464260
if (SanitizeExponent) {
@@ -4258,7 +4272,7 @@ Value *ScalarExprEmitter::EmitShl(const BinOpInfo &Ops) {
42584272
Builder.CreateCondBr(ValidExponent, CheckShiftBase, Cont);
42594273
llvm::Value *PromotedWidthMinusOne =
42604274
(RHS == Ops.RHS) ? WidthMinusOne
4261-
: GetMaximumShiftAmount(Ops.LHS, RHS);
4275+
: GetMaximumShiftAmount(Ops.LHS, RHS, RHSIsSigned);
42624276
CGF.EmitBlock(CheckShiftBase);
42634277
llvm::Value *BitsShiftedOff = Builder.CreateLShr(
42644278
Ops.LHS, Builder.CreateSub(PromotedWidthMinusOne, RHS, "shl.zeros",
@@ -4308,8 +4322,9 @@ Value *ScalarExprEmitter::EmitShr(const BinOpInfo &Ops) {
43084322
else if (CGF.SanOpts.has(SanitizerKind::ShiftExponent) &&
43094323
isa<llvm::IntegerType>(Ops.LHS->getType())) {
43104324
CodeGenFunction::SanitizerScope SanScope(&CGF);
4311-
llvm::Value *Valid =
4312-
Builder.CreateICmpULE(Ops.RHS, GetMaximumShiftAmount(Ops.LHS, Ops.RHS));
4325+
bool RHSIsSigned = Ops.rhsHasSignedIntegerRepresentation();
4326+
llvm::Value *Valid = Builder.CreateICmpULE(
4327+
Ops.RHS, GetMaximumShiftAmount(Ops.LHS, Ops.RHS, RHSIsSigned));
43134328
EmitBinOpCheck(std::make_pair(Valid, SanitizerKind::ShiftExponent), Ops);
43144329
}
43154330

clang/test/CodeGen/ubsan-shift-bitint.c

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,14 @@
66
// CHECK-LABEL: define{{.*}} i32 @test_left_variable
77
int test_left_variable(unsigned _BitInt(5) b, unsigned _BitInt(2) e) {
88
// CHECK: [[E_REG:%.+]] = load [[E_SIZE:i2]]
9-
// CHECK: icmp ule [[E_SIZE]] [[E_REG]], -1
9+
// CHECK: icmp ule [[E_SIZE]] [[E_REG]], -1,
1010
return b << e;
1111
}
1212

1313
// CHECK-LABEL: define{{.*}} i32 @test_right_variable
1414
int test_right_variable(unsigned _BitInt(2) b, unsigned _BitInt(3) e) {
1515
// CHECK: [[E_REG:%.+]] = load [[E_SIZE:i3]]
16-
// CHECK: icmp ule [[E_SIZE]] [[E_REG]], 1
16+
// CHECK: icmp ule [[E_SIZE]] [[E_REG]], 1,
1717
return b >> e;
1818
}
1919

@@ -34,3 +34,32 @@ int test_right_literal(unsigned _BitInt(2) b) {
3434
// CHECK: br i1 false, label %cont, label %handler.shift_out_of_bounds
3535
return b >> 4uwb;
3636
}
37+
38+
// CHECK-LABEL: define{{.*}} i32 @test_signed_left_variable
39+
int test_signed_left_variable(unsigned _BitInt(15) b, _BitInt(2) e) {
40+
// CHECK: [[E_REG:%.+]] = load [[E_SIZE:i2]]
41+
// CHECK: icmp ule [[E_SIZE]] [[E_REG]], 1,
42+
return b << e;
43+
}
44+
45+
// CHECK-LABEL: define{{.*}} i32 @test_signed_right_variable
46+
int test_signed_right_variable(unsigned _BitInt(32) b, _BitInt(4) e) {
47+
// CHECK: [[E_REG:%.+]] = load [[E_SIZE:i4]]
48+
// CHECK: icmp ule [[E_SIZE]] [[E_REG]], 7,
49+
return b >> e;
50+
}
51+
52+
// CHECK-LABEL: define{{.*}} i32 @test_signed_left_literal
53+
int test_signed_left_literal(unsigned _BitInt(16) b) {
54+
// CHECK-NOT: br i1 true, label %cont, label %handler.shift_out_of_bounds
55+
// CHECK: br i1 false, label %cont, label %handler.shift_out_of_bounds
56+
return b << (_BitInt(4))-2wb;
57+
}
58+
59+
// CHECK-LABEL: define{{.*}} i32 @test_signed_right_literal
60+
int test_signed_right_literal(unsigned _BitInt(16) b) {
61+
// CHECK-NOT: br i1 true, label %cont, label %handler.shift_out_of_bounds
62+
// CHECK: br i1 false, label %cont, label %handler.shift_out_of_bounds
63+
return b >> (_BitInt(4))-8wb;
64+
}
65+

0 commit comments

Comments
 (0)