-
Notifications
You must be signed in to change notification settings - Fork 14k
[BasicAA] Consider 'nneg' flag when comparing CastedValues #94129
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-llvm-analysis Author: Alex MacLean (AlexMaclean) ChangesAny of the Full diff: https://github.com/llvm/llvm-project/pull/94129.diff 2 Files Affected:
diff --git a/llvm/lib/Analysis/BasicAliasAnalysis.cpp b/llvm/lib/Analysis/BasicAliasAnalysis.cpp
index 3f456db1c51ac..826706d1306a9 100644
--- a/llvm/lib/Analysis/BasicAliasAnalysis.cpp
+++ b/llvm/lib/Analysis/BasicAliasAnalysis.cpp
@@ -268,31 +268,42 @@ struct CastedValue {
unsigned ZExtBits = 0;
unsigned SExtBits = 0;
unsigned TruncBits = 0;
+ bool IsNonNegative = false;
explicit CastedValue(const Value *V) : V(V) {}
explicit CastedValue(const Value *V, unsigned ZExtBits, unsigned SExtBits,
- unsigned TruncBits)
- : V(V), ZExtBits(ZExtBits), SExtBits(SExtBits), TruncBits(TruncBits) {}
+ unsigned TruncBits, bool IsNonNegative)
+ : V(V), ZExtBits(ZExtBits), SExtBits(SExtBits), TruncBits(TruncBits),
+ IsNonNegative(IsNonNegative) {}
unsigned getBitWidth() const {
return V->getType()->getPrimitiveSizeInBits() - TruncBits + ZExtBits +
SExtBits;
}
- CastedValue withValue(const Value *NewV) const {
- return CastedValue(NewV, ZExtBits, SExtBits, TruncBits);
+ CastedValue withValue(const Value *NewV, bool PreserveNonNeg) const {
+ return CastedValue(NewV, ZExtBits, SExtBits, TruncBits,
+ IsNonNegative && PreserveNonNeg);
}
/// Replace V with zext(NewV)
- CastedValue withZExtOfValue(const Value *NewV) const {
+ CastedValue withZExtOfValue(const Value *NewV, bool ZExtNonNegative) const {
unsigned ExtendBy = V->getType()->getPrimitiveSizeInBits() -
NewV->getType()->getPrimitiveSizeInBits();
if (ExtendBy <= TruncBits)
- return CastedValue(NewV, ZExtBits, SExtBits, TruncBits - ExtendBy);
+ // zext<nneg>(trunc(zext(NewV))) == zext<nneg>(trunc(NewV))
+ // The nneg can be preserved on the outer zext here
+ return CastedValue(NewV, ZExtBits, SExtBits, TruncBits - ExtendBy,
+ IsNonNegative);
// zext(sext(zext(NewV))) == zext(zext(zext(NewV)))
ExtendBy -= TruncBits;
- return CastedValue(NewV, ZExtBits + SExtBits + ExtendBy, 0, 0);
+ // zext<nneg>(zext(NewV)) == zext(NewV)
+ // zext(zext<nneg>(NewV)) == zext<nneg>(NewV)
+ // The nneg can be preserved from the inner zext here but must be dropped
+ // from the outer.
+ return CastedValue(NewV, ZExtBits + SExtBits + ExtendBy, 0, 0,
+ ZExtNonNegative);
}
/// Replace V with sext(NewV)
@@ -300,11 +311,16 @@ struct CastedValue {
unsigned ExtendBy = V->getType()->getPrimitiveSizeInBits() -
NewV->getType()->getPrimitiveSizeInBits();
if (ExtendBy <= TruncBits)
- return CastedValue(NewV, ZExtBits, SExtBits, TruncBits - ExtendBy);
+ // zext<nneg>(trunc(sext(NewV))) == zext<nneg>(trunc(NewV))
+ // The nneg can be preserved on the outer zext here
+ return CastedValue(NewV, ZExtBits, SExtBits, TruncBits - ExtendBy,
+ IsNonNegative);
// zext(sext(sext(NewV)))
ExtendBy -= TruncBits;
- return CastedValue(NewV, ZExtBits, SExtBits + ExtendBy, 0);
+ // zext<nneg>(sext(sext(NewV))) = zext<nneg>(sext(NewV))
+ // The nneg can be preserved on the outer zext here
+ return CastedValue(NewV, ZExtBits, SExtBits + ExtendBy, 0, IsNonNegative);
}
APInt evaluateWith(APInt N) const {
@@ -333,8 +349,15 @@ struct CastedValue {
}
bool hasSameCastsAs(const CastedValue &Other) const {
- return ZExtBits == Other.ZExtBits && SExtBits == Other.SExtBits &&
- TruncBits == Other.TruncBits;
+ if (ZExtBits == Other.ZExtBits && SExtBits == Other.SExtBits &&
+ TruncBits == Other.TruncBits)
+ return true;
+ // If either CastedValue has a nneg zext then the sext/zext bits are
+ // interchangable for that value.
+ if (IsNonNegative || Other.IsNonNegative)
+ return (ZExtBits + SExtBits == Other.ZExtBits + Other.SExtBits &&
+ TruncBits == Other.TruncBits);
+ return false;
}
};
@@ -410,21 +433,21 @@ static LinearExpression GetLinearExpression(
[[fallthrough]];
case Instruction::Add: {
- E = GetLinearExpression(Val.withValue(BOp->getOperand(0)), DL,
+ E = GetLinearExpression(Val.withValue(BOp->getOperand(0), false), DL,
Depth + 1, AC, DT);
E.Offset += RHS;
E.IsNSW &= NSW;
break;
}
case Instruction::Sub: {
- E = GetLinearExpression(Val.withValue(BOp->getOperand(0)), DL,
+ E = GetLinearExpression(Val.withValue(BOp->getOperand(0), false), DL,
Depth + 1, AC, DT);
E.Offset -= RHS;
E.IsNSW &= NSW;
break;
}
case Instruction::Mul:
- E = GetLinearExpression(Val.withValue(BOp->getOperand(0)), DL,
+ E = GetLinearExpression(Val.withValue(BOp->getOperand(0), false), DL,
Depth + 1, AC, DT)
.mul(RHS, NSW);
break;
@@ -437,7 +460,7 @@ static LinearExpression GetLinearExpression(
if (RHS.getLimitedValue() > Val.getBitWidth())
return Val;
- E = GetLinearExpression(Val.withValue(BOp->getOperand(0)), DL,
+ E = GetLinearExpression(Val.withValue(BOp->getOperand(0), NSW), DL,
Depth + 1, AC, DT);
E.Offset <<= RHS.getLimitedValue();
E.Scale <<= RHS.getLimitedValue();
@@ -450,7 +473,8 @@ static LinearExpression GetLinearExpression(
if (isa<ZExtInst>(Val.V))
return GetLinearExpression(
- Val.withZExtOfValue(cast<CastInst>(Val.V)->getOperand(0)),
+ Val.withZExtOfValue(cast<CastInst>(Val.V)->getOperand(0),
+ cast<ZExtInst>(Val.V)->hasNonNeg()),
DL, Depth + 1, AC, DT);
if (isa<SExtInst>(Val.V))
@@ -673,7 +697,7 @@ BasicAAResult::DecomposeGEPExpression(const Value *V, const DataLayout &DL,
unsigned SExtBits = IndexSize > Width ? IndexSize - Width : 0;
unsigned TruncBits = IndexSize < Width ? Width - IndexSize : 0;
LinearExpression LE = GetLinearExpression(
- CastedValue(Index, 0, SExtBits, TruncBits), DL, 0, AC, DT);
+ CastedValue(Index, 0, SExtBits, TruncBits, false), DL, 0, AC, DT);
// Scale by the type size.
unsigned TypeSize = AllocTypeSize.getFixedValue();
diff --git a/llvm/test/Analysis/BasicAA/zext-nneg.ll b/llvm/test/Analysis/BasicAA/zext-nneg.ll
new file mode 100644
index 0000000000000..808bb1a8c9d96
--- /dev/null
+++ b/llvm/test/Analysis/BasicAA/zext-nneg.ll
@@ -0,0 +1,181 @@
+; RUN: opt < %s -aa-pipeline=basic-aa -passes=aa-eval -print-all-alias-modref-info -disable-output 2>&1 | FileCheck %s
+
+;; Simple case: a zext nneg can be replaced with a sext. Make sure BasicAA
+;; understands that.
+define void @t1(i32 %a, i32 %b) {
+; CHECK-LABEL: Function: t1
+; CHECK: NoAlias: float* %gep1, float* %gep2
+
+ %1 = alloca [8 x float], align 4
+ %or1 = or i32 %a, 1
+ %2 = sext i32 %or1 to i64
+ %gep1 = getelementptr inbounds float, ptr %1, i64 %2
+
+ %shl1 = shl i32 %b, 1
+ %3 = zext nneg i32 %shl1 to i64
+ %gep2 = getelementptr inbounds float, ptr %1, i64 %3
+
+ load float, ptr %gep1
+ load float, ptr %gep2
+ ret void
+}
+
+;; A (zext nneg (sext V)) is equivalent to a (zext (sext V)) as long as the
+;; total number of zext+sext bits is the same for both.
+define void @t2(i8 %a, i8 %b) {
+; CHECK-LABEL: Function: t2
+; CHECK: NoAlias: float* %gep1, float* %gep2
+ %1 = alloca [8 x float], align 4
+ %or1 = or i8 %a, 1
+ %2 = sext i8 %or1 to i32
+ %3 = zext i32 %2 to i64
+ %gep1 = getelementptr inbounds float, ptr %1, i64 %3
+
+ %shl1 = shl i8 %b, 1
+ %4 = sext i8 %shl1 to i16
+ %5 = zext nneg i16 %4 to i64
+ %gep2 = getelementptr inbounds float, ptr %1, i64 %5
+
+ load float, ptr %gep1
+ load float, ptr %gep2
+ ret void
+}
+
+;; Here the %a and %b are knowably non-equal. In this cases we can distribute
+;; the zext, preserving the nneg flag, through the shl because it has a nsw flag
+define void @t3(i8 %v) {
+; CHECK-LABEL: Function: t3
+; CHECK: NoAlias: <2 x float>* %gep1, <2 x float>* %gep2
+ %a = or i8 %v, 1
+ %b = and i8 %v, 2
+
+ %1 = alloca [8 x float], align 4
+ %or1 = shl nuw nsw i8 %a, 1
+ %2 = zext nneg i8 %or1 to i64
+ %gep1 = getelementptr inbounds float, ptr %1, i64 %2
+
+ %m = mul nsw nuw i8 %b, 2
+ %3 = sext i8 %m to i16
+ %4 = zext i16 %3 to i64
+ %gep2 = getelementptr inbounds float, ptr %1, i64 %4
+
+ load <2 x float>, ptr %gep1
+ load <2 x float>, ptr %gep2
+ ret void
+}
+
+;; This is the same as above, but this time the shl does not have the nsw flag.
+;; the nneg cannot be kept on the zext.
+define void @t4(i8 %v) {
+; CHECK-LABEL: Function: t4
+; CHECK: MayAlias: <2 x float>* %gep1, <2 x float>* %gep2
+ %a = or i8 %v, 1
+ %b = and i8 %v, 2
+
+ %1 = alloca [8 x float], align 4
+ %or1 = shl nuw i8 %a, 1
+ %2 = zext nneg i8 %or1 to i64
+ %gep1 = getelementptr inbounds float, ptr %1, i64 %2
+
+ %m = mul nsw nuw i8 %b, 2
+ %3 = sext i8 %m to i16
+ %4 = zext i16 %3 to i64
+ %gep2 = getelementptr inbounds float, ptr %1, i64 %4
+
+ load <2 x float>, ptr %gep1
+ load <2 x float>, ptr %gep2
+ ret void
+}
+
+;; Verify a zext nneg and a zext are understood as the same
+define void @t5(ptr %p, i16 %i) {
+; CHECK-LABEL: Function: t5
+; CHECK: NoAlias: i32* %pi, i32* %pi.next
+ %i1 = zext nneg i16 %i to i32
+ %pi = getelementptr i32, ptr %p, i32 %i1
+
+ %i.next = add i16 %i, 1
+ %i.next2 = zext i16 %i.next to i32
+ %pi.next = getelementptr i32, ptr %p, i32 %i.next2
+
+ load i32, ptr %pi
+ load i32, ptr %pi.next
+ ret void
+}
+
+;; This is not very idiomatic, but still possible, verify the nneg is propagated
+;; outward. and that no alias is correctly identified.
+define void @t6(i8 %a) {
+; CHECK-LABEL: Function: t6
+; CHECK: NoAlias: float* %gep1, float* %gep2
+ %1 = alloca [8 x float], align 4
+ %a.add = add i8 %a, 1
+ %2 = zext nneg i8 %a.add to i16
+ %3 = sext i16 %2 to i32
+ %4 = zext i32 %3 to i64
+ %gep1 = getelementptr inbounds float, ptr %1, i64 %4
+
+ %5 = sext i8 %a to i64
+ %gep2 = getelementptr inbounds float, ptr %1, i64 %5
+
+ load float, ptr %gep1
+ load float, ptr %gep2
+ ret void
+}
+
+;; This is even less idiomatic, but still possible, verify the nneg is not
+;; propagated inward. and that may alias is correctly identified.
+define void @t7(i8 %a) {
+; CHECK-LABEL: Function: t7
+; CHECK: MayAlias: float* %gep1, float* %gep2
+ %1 = alloca [8 x float], align 4
+ %a.add = add i8 %a, 1
+ %2 = zext i8 %a.add to i16
+ %3 = sext i16 %2 to i32
+ %4 = zext nneg i32 %3 to i64
+ %gep1 = getelementptr inbounds float, ptr %1, i64 %4
+
+ %5 = sext i8 %a to i64
+ %gep2 = getelementptr inbounds float, ptr %1, i64 %5
+
+ load float, ptr %gep1
+ load float, ptr %gep2
+ ret void
+}
+
+;; Verify the nneg survives an implicit trunc of fewer bits then the zext.
+define void @t8(i8 %a) {
+; CHECK-LABEL: Function: t8
+; CHECK: NoAlias: float* %gep1, float* %gep2
+ %1 = alloca [8 x float], align 4
+ %a.add = add i8 %a, 1
+ %2 = zext nneg i8 %a.add to i128
+ %gep1 = getelementptr inbounds float, ptr %1, i128 %2
+
+ %3 = sext i8 %a to i64
+ %gep2 = getelementptr inbounds float, ptr %1, i64 %3
+
+ load float, ptr %gep1
+ load float, ptr %gep2
+ ret void
+}
+
+;; Ensure that the nneg is never propagated past this trunc and that these
+;; casted values are understood as non-equal.
+define void @t9(i8 %a) {
+; CHECK-LABEL: Function: t9
+; CHECK: MayAlias: float* %gep1, float* %gep2
+ %1 = alloca [8 x float], align 4
+ %a.add = add i8 %a, 1
+ %2 = zext i8 %a.add to i16
+ %3 = trunc i16 %2 to i1
+ %4 = zext nneg i1 %3 to i64
+ %gep1 = getelementptr inbounds float, ptr %1, i64 %4
+
+ %5 = sext i8 %a to i64
+ %gep2 = getelementptr inbounds float, ptr %1, i64 %5
+
+ load float, ptr %gep1
+ load float, ptr %gep2
+ ret void
+}
|
@@ -268,43 +268,59 @@ struct CastedValue { | |||
unsigned ZExtBits = 0; | |||
unsigned SExtBits = 0; | |||
unsigned TruncBits = 0; | |||
bool IsNonNegative = false; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
/// Whether trunc(V) is non-negative.
Should clarify that this does not apply to V
.
unsigned ExtendBy = V->getType()->getPrimitiveSizeInBits() - | ||
NewV->getType()->getPrimitiveSizeInBits(); | ||
if (ExtendBy <= TruncBits) | ||
return CastedValue(NewV, ZExtBits, SExtBits, TruncBits - ExtendBy); | ||
// zext<nneg>(trunc(zext(NewV))) == zext<nneg>(trunc(NewV)) | ||
// The nneg can be preserved on the outer zext here |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
// The nneg can be preserved on the outer zext here | |
// The nneg can be preserved on the outer zext here. |
@@ -450,7 +473,8 @@ static LinearExpression GetLinearExpression( | |||
|
|||
if (isa<ZExtInst>(Val.V)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
dyn_cast here instead.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM apart from the nits. (I deleted some other review comments that I changed my mind on.)
Any of the
zext
bits in azext nneg
can be converted tosext
but when checking if casts are compatibleBasicAA
fails to take into accountnneg
. This change adds tracking ofnneg
to theCastedValue
struct and ensures thatsext
andzext
bits are treated as interchangeable when eitherCastedValue
has anneg
. When distributing casted values inGetLinearExpression
we conservatively discard thenneg
from theCastedValue
, except in the case ofshl nsw
, where we know the sign has not changed to negative.