Skip to content

[BasicAA] Consider 'nneg' flag when comparing CastedValues #94129

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jun 4, 2024

Conversation

AlexMaclean
Copy link
Member

Any of the zext bits in a zext nneg can be converted to sext, but when checking whether casts are compatible, BasicAA fails to take nneg into account. This change adds tracking of nneg to the CastedValue struct and ensures that sext and zext bits are treated as interchangeable when either CastedValue has a nneg. When distributing casted values in GetLinearExpression, we conservatively discard the nneg from the CastedValue, except in the case of shl nsw, where we know the sign has not changed to negative.

@AlexMaclean AlexMaclean self-assigned this Jun 2, 2024
@AlexMaclean AlexMaclean requested a review from nikic as a code owner June 2, 2024 00:39
@llvmbot
Copy link
Member

llvmbot commented Jun 2, 2024

@llvm/pr-subscribers-llvm-analysis

Author: Alex MacLean (AlexMaclean)

Changes

Any of the zext bits in a zext nneg can be converted to sext, but when checking whether casts are compatible, BasicAA fails to take nneg into account. This change adds tracking of nneg to the CastedValue struct and ensures that sext and zext bits are treated as interchangeable when either CastedValue has a nneg. When distributing casted values in GetLinearExpression, we conservatively discard the nneg from the CastedValue, except in the case of shl nsw, where we know the sign has not changed to negative.


Full diff: https://github.com/llvm/llvm-project/pull/94129.diff

2 Files Affected:

  • (modified) llvm/lib/Analysis/BasicAliasAnalysis.cpp (+41-17)
  • (added) llvm/test/Analysis/BasicAA/zext-nneg.ll (+181)
diff --git a/llvm/lib/Analysis/BasicAliasAnalysis.cpp b/llvm/lib/Analysis/BasicAliasAnalysis.cpp
index 3f456db1c51ac..826706d1306a9 100644
--- a/llvm/lib/Analysis/BasicAliasAnalysis.cpp
+++ b/llvm/lib/Analysis/BasicAliasAnalysis.cpp
@@ -268,31 +268,42 @@ struct CastedValue {
   unsigned ZExtBits = 0;
   unsigned SExtBits = 0;
   unsigned TruncBits = 0;
+  bool IsNonNegative = false;
 
   explicit CastedValue(const Value *V) : V(V) {}
   explicit CastedValue(const Value *V, unsigned ZExtBits, unsigned SExtBits,
-                       unsigned TruncBits)
-      : V(V), ZExtBits(ZExtBits), SExtBits(SExtBits), TruncBits(TruncBits) {}
+                       unsigned TruncBits, bool IsNonNegative)
+      : V(V), ZExtBits(ZExtBits), SExtBits(SExtBits), TruncBits(TruncBits),
+        IsNonNegative(IsNonNegative) {}
 
   unsigned getBitWidth() const {
     return V->getType()->getPrimitiveSizeInBits() - TruncBits + ZExtBits +
            SExtBits;
   }
 
-  CastedValue withValue(const Value *NewV) const {
-    return CastedValue(NewV, ZExtBits, SExtBits, TruncBits);
+  CastedValue withValue(const Value *NewV, bool PreserveNonNeg) const {
+    return CastedValue(NewV, ZExtBits, SExtBits, TruncBits,
+                       IsNonNegative && PreserveNonNeg);
   }
 
   /// Replace V with zext(NewV)
-  CastedValue withZExtOfValue(const Value *NewV) const {
+  CastedValue withZExtOfValue(const Value *NewV, bool ZExtNonNegative) const {
     unsigned ExtendBy = V->getType()->getPrimitiveSizeInBits() -
                         NewV->getType()->getPrimitiveSizeInBits();
     if (ExtendBy <= TruncBits)
-      return CastedValue(NewV, ZExtBits, SExtBits, TruncBits - ExtendBy);
+      // zext<nneg>(trunc(zext(NewV))) == zext<nneg>(trunc(NewV))
+      // The nneg can be preserved on the outer zext here
+      return CastedValue(NewV, ZExtBits, SExtBits, TruncBits - ExtendBy,
+                         IsNonNegative);
 
     // zext(sext(zext(NewV))) == zext(zext(zext(NewV)))
     ExtendBy -= TruncBits;
-    return CastedValue(NewV, ZExtBits + SExtBits + ExtendBy, 0, 0);
+    // zext<nneg>(zext(NewV)) == zext(NewV)
+    // zext(zext<nneg>(NewV)) == zext<nneg>(NewV)
+    // The nneg can be preserved from the inner zext here but must be dropped
+    // from the outer.
+    return CastedValue(NewV, ZExtBits + SExtBits + ExtendBy, 0, 0,
+                       ZExtNonNegative);
   }
 
   /// Replace V with sext(NewV)
@@ -300,11 +311,16 @@ struct CastedValue {
     unsigned ExtendBy = V->getType()->getPrimitiveSizeInBits() -
                         NewV->getType()->getPrimitiveSizeInBits();
     if (ExtendBy <= TruncBits)
-      return CastedValue(NewV, ZExtBits, SExtBits, TruncBits - ExtendBy);
+      // zext<nneg>(trunc(sext(NewV))) == zext<nneg>(trunc(NewV))
+      // The nneg can be preserved on the outer zext here
+      return CastedValue(NewV, ZExtBits, SExtBits, TruncBits - ExtendBy,
+                         IsNonNegative);
 
     // zext(sext(sext(NewV)))
     ExtendBy -= TruncBits;
-    return CastedValue(NewV, ZExtBits, SExtBits + ExtendBy, 0);
+    // zext<nneg>(sext(sext(NewV))) = zext<nneg>(sext(NewV))
+    // The nneg can be preserved on the outer zext here
+    return CastedValue(NewV, ZExtBits, SExtBits + ExtendBy, 0, IsNonNegative);
   }
 
   APInt evaluateWith(APInt N) const {
@@ -333,8 +349,15 @@ struct CastedValue {
   }
 
   bool hasSameCastsAs(const CastedValue &Other) const {
-    return ZExtBits == Other.ZExtBits && SExtBits == Other.SExtBits &&
-           TruncBits == Other.TruncBits;
+    if (ZExtBits == Other.ZExtBits && SExtBits == Other.SExtBits &&
+        TruncBits == Other.TruncBits)
+      return true;
+    // If either CastedValue has a nneg zext then the sext/zext bits are
+    // interchangeable for that value.
+    if (IsNonNegative || Other.IsNonNegative)
+      return (ZExtBits + SExtBits == Other.ZExtBits + Other.SExtBits &&
+              TruncBits == Other.TruncBits);
+    return false;
   }
 };
 
@@ -410,21 +433,21 @@ static LinearExpression GetLinearExpression(
 
         [[fallthrough]];
       case Instruction::Add: {
-        E = GetLinearExpression(Val.withValue(BOp->getOperand(0)), DL,
+        E = GetLinearExpression(Val.withValue(BOp->getOperand(0), false), DL,
                                 Depth + 1, AC, DT);
         E.Offset += RHS;
         E.IsNSW &= NSW;
         break;
       }
       case Instruction::Sub: {
-        E = GetLinearExpression(Val.withValue(BOp->getOperand(0)), DL,
+        E = GetLinearExpression(Val.withValue(BOp->getOperand(0), false), DL,
                                 Depth + 1, AC, DT);
         E.Offset -= RHS;
         E.IsNSW &= NSW;
         break;
       }
       case Instruction::Mul:
-        E = GetLinearExpression(Val.withValue(BOp->getOperand(0)), DL,
+        E = GetLinearExpression(Val.withValue(BOp->getOperand(0), false), DL,
                                 Depth + 1, AC, DT)
                 .mul(RHS, NSW);
         break;
@@ -437,7 +460,7 @@ static LinearExpression GetLinearExpression(
         if (RHS.getLimitedValue() > Val.getBitWidth())
           return Val;
 
-        E = GetLinearExpression(Val.withValue(BOp->getOperand(0)), DL,
+        E = GetLinearExpression(Val.withValue(BOp->getOperand(0), NSW), DL,
                                 Depth + 1, AC, DT);
         E.Offset <<= RHS.getLimitedValue();
         E.Scale <<= RHS.getLimitedValue();
@@ -450,7 +473,8 @@ static LinearExpression GetLinearExpression(
 
   if (isa<ZExtInst>(Val.V))
     return GetLinearExpression(
-        Val.withZExtOfValue(cast<CastInst>(Val.V)->getOperand(0)),
+        Val.withZExtOfValue(cast<CastInst>(Val.V)->getOperand(0),
+                            cast<ZExtInst>(Val.V)->hasNonNeg()),
         DL, Depth + 1, AC, DT);
 
   if (isa<SExtInst>(Val.V))
@@ -673,7 +697,7 @@ BasicAAResult::DecomposeGEPExpression(const Value *V, const DataLayout &DL,
       unsigned SExtBits = IndexSize > Width ? IndexSize - Width : 0;
       unsigned TruncBits = IndexSize < Width ? Width - IndexSize : 0;
       LinearExpression LE = GetLinearExpression(
-          CastedValue(Index, 0, SExtBits, TruncBits), DL, 0, AC, DT);
+          CastedValue(Index, 0, SExtBits, TruncBits, false), DL, 0, AC, DT);
 
       // Scale by the type size.
       unsigned TypeSize = AllocTypeSize.getFixedValue();
diff --git a/llvm/test/Analysis/BasicAA/zext-nneg.ll b/llvm/test/Analysis/BasicAA/zext-nneg.ll
new file mode 100644
index 0000000000000..808bb1a8c9d96
--- /dev/null
+++ b/llvm/test/Analysis/BasicAA/zext-nneg.ll
@@ -0,0 +1,181 @@
+; RUN: opt < %s -aa-pipeline=basic-aa -passes=aa-eval -print-all-alias-modref-info -disable-output 2>&1 | FileCheck %s
+
+;; Simple case: a zext nneg can be replaced with a sext. Make sure BasicAA
+;; understands that.
+define void @t1(i32 %a, i32 %b) {
+; CHECK-LABEL: Function: t1
+; CHECK: NoAlias: float* %gep1, float* %gep2
+
+  %1 = alloca [8 x float], align 4
+  %or1 = or i32 %a, 1
+  %2 = sext i32 %or1 to i64
+  %gep1 = getelementptr inbounds float, ptr %1, i64 %2
+
+  %shl1 = shl i32 %b, 1
+  %3 = zext nneg i32 %shl1 to i64
+  %gep2 = getelementptr inbounds float, ptr %1, i64 %3
+
+  load float, ptr %gep1
+  load float, ptr %gep2
+  ret void
+}
+
+;; A (zext nneg (sext V)) is equivalent to a (zext (sext V)) as long as the
+;; total number of zext+sext bits is the same for both.
+define void @t2(i8 %a, i8 %b) {
+; CHECK-LABEL: Function: t2
+; CHECK: NoAlias: float* %gep1, float* %gep2
+  %1 = alloca [8 x float], align 4
+  %or1 = or i8 %a, 1
+  %2 = sext i8 %or1 to i32
+  %3 = zext i32 %2 to i64
+  %gep1 = getelementptr inbounds float, ptr %1, i64 %3
+
+  %shl1 = shl i8 %b, 1
+  %4 = sext i8 %shl1 to i16
+  %5 = zext nneg i16 %4 to i64
+  %gep2 = getelementptr inbounds float, ptr %1, i64 %5
+
+  load float, ptr %gep1
+  load float, ptr %gep2
+  ret void
+}
+
+;; Here %a and %b are knowably non-equal. In this case we can distribute
+;; the zext, preserving the nneg flag, through the shl because it has an nsw flag.
+define void @t3(i8 %v) {
+; CHECK-LABEL: Function: t3
+; CHECK: NoAlias: <2 x float>* %gep1, <2 x float>* %gep2
+  %a = or i8 %v, 1
+  %b = and i8 %v, 2
+
+  %1 = alloca [8 x float], align 4
+  %or1 = shl nuw nsw i8 %a, 1
+  %2 = zext nneg i8 %or1 to i64
+  %gep1 = getelementptr inbounds float, ptr %1, i64 %2
+
+  %m = mul nsw nuw i8 %b, 2
+  %3 = sext i8 %m to i16
+  %4 = zext i16 %3 to i64
+  %gep2 = getelementptr inbounds float, ptr %1, i64 %4
+
+  load <2 x float>, ptr %gep1
+  load <2 x float>, ptr %gep2
+  ret void
+}
+
+;; This is the same as above, but this time the shl does not have the nsw flag,
+;; so the nneg cannot be kept on the zext.
+define void @t4(i8 %v) {
+; CHECK-LABEL: Function: t4
+; CHECK: MayAlias: <2 x float>* %gep1, <2 x float>* %gep2
+  %a = or i8 %v, 1
+  %b = and i8 %v, 2
+
+  %1 = alloca [8 x float], align 4
+  %or1 = shl nuw i8 %a, 1
+  %2 = zext nneg i8 %or1 to i64
+  %gep1 = getelementptr inbounds float, ptr %1, i64 %2
+
+  %m = mul nsw nuw i8 %b, 2
+  %3 = sext i8 %m to i16
+  %4 = zext i16 %3 to i64
+  %gep2 = getelementptr inbounds float, ptr %1, i64 %4
+
+  load <2 x float>, ptr %gep1
+  load <2 x float>, ptr %gep2
+  ret void
+}
+
+;; Verify a zext nneg and a zext are understood as the same
+define void @t5(ptr %p, i16 %i) {
+; CHECK-LABEL: Function: t5
+; CHECK: NoAlias: i32* %pi, i32* %pi.next
+  %i1 = zext nneg i16 %i to i32
+  %pi = getelementptr i32, ptr %p, i32 %i1
+
+  %i.next = add i16 %i, 1
+  %i.next2 = zext i16 %i.next to i32
+  %pi.next = getelementptr i32, ptr %p, i32 %i.next2
+
+  load i32, ptr %pi
+  load i32, ptr %pi.next
+  ret void
+}
+
+;; This is not very idiomatic, but still possible. Verify the nneg is propagated
+;; outward and that no alias is correctly identified.
+define void @t6(i8 %a) {
+; CHECK-LABEL: Function: t6
+; CHECK: NoAlias: float* %gep1, float* %gep2
+  %1 = alloca [8 x float], align 4
+  %a.add = add i8 %a, 1
+  %2 = zext nneg i8 %a.add to i16
+  %3 = sext i16 %2 to i32
+  %4 = zext i32 %3 to i64
+  %gep1 = getelementptr inbounds float, ptr %1, i64 %4
+
+  %5 = sext i8 %a to i64
+  %gep2 = getelementptr inbounds float, ptr %1, i64 %5
+
+  load float, ptr %gep1
+  load float, ptr %gep2
+  ret void
+}
+
+;; This is even less idiomatic, but still possible. Verify the nneg is not
+;; propagated inward and that may alias is correctly identified.
+define void @t7(i8 %a) {
+; CHECK-LABEL: Function: t7
+; CHECK: MayAlias: float* %gep1, float* %gep2
+  %1 = alloca [8 x float], align 4
+  %a.add = add i8 %a, 1
+  %2 = zext i8 %a.add to i16
+  %3 = sext i16 %2 to i32
+  %4 = zext nneg i32 %3 to i64
+  %gep1 = getelementptr inbounds float, ptr %1, i64 %4
+
+  %5 = sext i8 %a to i64
+  %gep2 = getelementptr inbounds float, ptr %1, i64 %5
+
+  load float, ptr %gep1
+  load float, ptr %gep2
+  ret void
+}
+
+;; Verify the nneg survives an implicit trunc of fewer bits than the zext.
+define void @t8(i8 %a) {
+; CHECK-LABEL: Function: t8
+; CHECK: NoAlias: float* %gep1, float* %gep2
+  %1 = alloca [8 x float], align 4
+  %a.add = add i8 %a, 1
+  %2 = zext nneg i8 %a.add to i128
+  %gep1 = getelementptr inbounds float, ptr %1, i128 %2
+
+  %3 = sext i8 %a to i64
+  %gep2 = getelementptr inbounds float, ptr %1, i64 %3
+
+  load float, ptr %gep1
+  load float, ptr %gep2
+  ret void
+}
+
+;; Ensure that the nneg is never propagated past this trunc and that these
+;; casted values are understood as non-equal.
+define void @t9(i8 %a) {
+; CHECK-LABEL: Function: t9
+; CHECK: MayAlias: float* %gep1, float* %gep2
+  %1 = alloca [8 x float], align 4
+  %a.add = add i8 %a, 1
+  %2 = zext i8 %a.add to i16
+  %3 = trunc i16 %2 to i1
+  %4 = zext nneg i1 %3 to i64
+  %gep1 = getelementptr inbounds float, ptr %1, i64 %4
+
+  %5 = sext i8 %a to i64
+  %gep2 = getelementptr inbounds float, ptr %1, i64 %5
+
+  load float, ptr %gep1
+  load float, ptr %gep2
+  ret void
+}

@@ -268,43 +268,59 @@ struct CastedValue {
unsigned ZExtBits = 0;
unsigned SExtBits = 0;
unsigned TruncBits = 0;
bool IsNonNegative = false;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

/// Whether trunc(V) is non-negative. Should clarify that this does not apply to V.

unsigned ExtendBy = V->getType()->getPrimitiveSizeInBits() -
NewV->getType()->getPrimitiveSizeInBits();
if (ExtendBy <= TruncBits)
return CastedValue(NewV, ZExtBits, SExtBits, TruncBits - ExtendBy);
// zext<nneg>(trunc(zext(NewV))) == zext<nneg>(trunc(NewV))
// The nneg can be preserved on the outer zext here
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
// The nneg can be preserved on the outer zext here
// The nneg can be preserved on the outer zext here.

@@ -450,7 +473,8 @@ static LinearExpression GetLinearExpression(

if (isa<ZExtInst>(Val.V))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

dyn_cast here instead.

Copy link
Contributor

@nikic nikic left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM apart from the nits. (I deleted some other review comments that I changed my mind on.)

@AlexMaclean AlexMaclean merged commit d881bac into llvm:main Jun 4, 2024
7 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants