Skip to content

[BasicAA] Consider 'nneg' flag when comparing CastedValues #94129

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jun 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 43 additions & 19 deletions llvm/lib/Analysis/BasicAliasAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -268,43 +268,60 @@ struct CastedValue {
unsigned ZExtBits = 0;
unsigned SExtBits = 0;
unsigned TruncBits = 0;
/// Whether trunc(V) is non-negative.
bool IsNonNegative = false;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

/// Whether trunc(V) is non-negative. Should clarify that this does not apply to V.


explicit CastedValue(const Value *V) : V(V) {}
explicit CastedValue(const Value *V, unsigned ZExtBits, unsigned SExtBits,
unsigned TruncBits)
: V(V), ZExtBits(ZExtBits), SExtBits(SExtBits), TruncBits(TruncBits) {}
unsigned TruncBits, bool IsNonNegative)
: V(V), ZExtBits(ZExtBits), SExtBits(SExtBits), TruncBits(TruncBits),
IsNonNegative(IsNonNegative) {}

unsigned getBitWidth() const {
return V->getType()->getPrimitiveSizeInBits() - TruncBits + ZExtBits +
SExtBits;
}

CastedValue withValue(const Value *NewV) const {
return CastedValue(NewV, ZExtBits, SExtBits, TruncBits);
CastedValue withValue(const Value *NewV, bool PreserveNonNeg) const {
return CastedValue(NewV, ZExtBits, SExtBits, TruncBits,
IsNonNegative && PreserveNonNeg);
}

/// Replace V with zext(NewV)
CastedValue withZExtOfValue(const Value *NewV) const {
CastedValue withZExtOfValue(const Value *NewV, bool ZExtNonNegative) const {
unsigned ExtendBy = V->getType()->getPrimitiveSizeInBits() -
NewV->getType()->getPrimitiveSizeInBits();
if (ExtendBy <= TruncBits)
return CastedValue(NewV, ZExtBits, SExtBits, TruncBits - ExtendBy);
// zext<nneg>(trunc(zext(NewV))) == zext<nneg>(trunc(NewV))
// The nneg can be preserved on the outer zext here.
return CastedValue(NewV, ZExtBits, SExtBits, TruncBits - ExtendBy,
IsNonNegative);

// zext(sext(zext(NewV))) == zext(zext(zext(NewV)))
ExtendBy -= TruncBits;
return CastedValue(NewV, ZExtBits + SExtBits + ExtendBy, 0, 0);
// zext<nneg>(zext(NewV)) == zext(NewV)
// zext(zext<nneg>(NewV)) == zext<nneg>(NewV)
// The nneg can be preserved from the inner zext here but must be dropped
// from the outer.
return CastedValue(NewV, ZExtBits + SExtBits + ExtendBy, 0, 0,
ZExtNonNegative);
}

/// Replace V with sext(NewV)
CastedValue withSExtOfValue(const Value *NewV) const {
unsigned ExtendBy = V->getType()->getPrimitiveSizeInBits() -
NewV->getType()->getPrimitiveSizeInBits();
if (ExtendBy <= TruncBits)
return CastedValue(NewV, ZExtBits, SExtBits, TruncBits - ExtendBy);
// zext<nneg>(trunc(sext(NewV))) == zext<nneg>(trunc(NewV))
// The nneg can be preserved on the outer zext here
return CastedValue(NewV, ZExtBits, SExtBits, TruncBits - ExtendBy,
IsNonNegative);

// zext(sext(sext(NewV)))
ExtendBy -= TruncBits;
return CastedValue(NewV, ZExtBits, SExtBits + ExtendBy, 0);
// zext<nneg>(sext(sext(NewV))) = zext<nneg>(sext(NewV))
// The nneg can be preserved on the outer zext here
return CastedValue(NewV, ZExtBits, SExtBits + ExtendBy, 0, IsNonNegative);
}

APInt evaluateWith(APInt N) const {
Expand Down Expand Up @@ -333,8 +350,15 @@ struct CastedValue {
}

bool hasSameCastsAs(const CastedValue &Other) const {
return ZExtBits == Other.ZExtBits && SExtBits == Other.SExtBits &&
TruncBits == Other.TruncBits;
if (ZExtBits == Other.ZExtBits && SExtBits == Other.SExtBits &&
TruncBits == Other.TruncBits)
return true;
// If either CastedValue has a nneg zext then the sext/zext bits are
// interchangable for that value.
if (IsNonNegative || Other.IsNonNegative)
return (ZExtBits + SExtBits == Other.ZExtBits + Other.SExtBits &&
TruncBits == Other.TruncBits);
return false;
}
};

Expand Down Expand Up @@ -410,21 +434,21 @@ static LinearExpression GetLinearExpression(

[[fallthrough]];
case Instruction::Add: {
E = GetLinearExpression(Val.withValue(BOp->getOperand(0)), DL,
E = GetLinearExpression(Val.withValue(BOp->getOperand(0), false), DL,
Depth + 1, AC, DT);
E.Offset += RHS;
E.IsNSW &= NSW;
break;
}
case Instruction::Sub: {
E = GetLinearExpression(Val.withValue(BOp->getOperand(0)), DL,
E = GetLinearExpression(Val.withValue(BOp->getOperand(0), false), DL,
Depth + 1, AC, DT);
E.Offset -= RHS;
E.IsNSW &= NSW;
break;
}
case Instruction::Mul:
E = GetLinearExpression(Val.withValue(BOp->getOperand(0)), DL,
E = GetLinearExpression(Val.withValue(BOp->getOperand(0), false), DL,
Depth + 1, AC, DT)
.mul(RHS, NSW);
break;
Expand All @@ -437,7 +461,7 @@ static LinearExpression GetLinearExpression(
if (RHS.getLimitedValue() > Val.getBitWidth())
return Val;

E = GetLinearExpression(Val.withValue(BOp->getOperand(0)), DL,
E = GetLinearExpression(Val.withValue(BOp->getOperand(0), NSW), DL,
Depth + 1, AC, DT);
E.Offset <<= RHS.getLimitedValue();
E.Scale <<= RHS.getLimitedValue();
Expand All @@ -448,10 +472,10 @@ static LinearExpression GetLinearExpression(
}
}

if (isa<ZExtInst>(Val.V))
if (const auto *ZExt = dyn_cast<ZExtInst>(Val.V))
return GetLinearExpression(
Val.withZExtOfValue(cast<CastInst>(Val.V)->getOperand(0)),
DL, Depth + 1, AC, DT);
Val.withZExtOfValue(ZExt->getOperand(0), ZExt->hasNonNeg()), DL,
Depth + 1, AC, DT);

if (isa<SExtInst>(Val.V))
return GetLinearExpression(
Expand Down Expand Up @@ -673,7 +697,7 @@ BasicAAResult::DecomposeGEPExpression(const Value *V, const DataLayout &DL,
unsigned SExtBits = IndexSize > Width ? IndexSize - Width : 0;
unsigned TruncBits = IndexSize < Width ? Width - IndexSize : 0;
LinearExpression LE = GetLinearExpression(
CastedValue(Index, 0, SExtBits, TruncBits), DL, 0, AC, DT);
CastedValue(Index, 0, SExtBits, TruncBits, false), DL, 0, AC, DT);

// Scale by the type size.
unsigned TypeSize = AllocTypeSize.getFixedValue();
Expand Down
181 changes: 181 additions & 0 deletions llvm/test/Analysis/BasicAA/zext-nneg.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@
; RUN: opt < %s -aa-pipeline=basic-aa -passes=aa-eval -print-all-alias-modref-info -disable-output 2>&1 | FileCheck %s

;; Simple case: a zext nneg can be replaced with a sext. Make sure BasicAA
;; understands that.
define void @t1(i32 %a, i32 %b) {
; CHECK-LABEL: Function: t1
; CHECK: NoAlias: float* %gep1, float* %gep2

%1 = alloca [8 x float], align 4
%or1 = or i32 %a, 1
%2 = sext i32 %or1 to i64
%gep1 = getelementptr inbounds float, ptr %1, i64 %2

%shl1 = shl i32 %b, 1
%3 = zext nneg i32 %shl1 to i64
%gep2 = getelementptr inbounds float, ptr %1, i64 %3

load float, ptr %gep1
load float, ptr %gep2
ret void
}

;; A (zext nneg (sext V)) is equivalent to a (zext (sext V)) as long as the
;; total number of zext+sext bits is the same for both.
define void @t2(i8 %a, i8 %b) {
; CHECK-LABEL: Function: t2
; CHECK: NoAlias: float* %gep1, float* %gep2
%1 = alloca [8 x float], align 4
%or1 = or i8 %a, 1
%2 = sext i8 %or1 to i32
%3 = zext i32 %2 to i64
%gep1 = getelementptr inbounds float, ptr %1, i64 %3

%shl1 = shl i8 %b, 1
%4 = sext i8 %shl1 to i16
%5 = zext nneg i16 %4 to i64
%gep2 = getelementptr inbounds float, ptr %1, i64 %5

load float, ptr %gep1
load float, ptr %gep2
ret void
}

;; Here the %a and %b are knowably non-equal. In this cases we can distribute
;; the zext, preserving the nneg flag, through the shl because it has a nsw flag
define void @t3(i8 %v) {
; CHECK-LABEL: Function: t3
; CHECK: NoAlias: <2 x float>* %gep1, <2 x float>* %gep2
%a = or i8 %v, 1
%b = and i8 %v, 2

%1 = alloca [8 x float], align 4
%or1 = shl nuw nsw i8 %a, 1
%2 = zext nneg i8 %or1 to i64
%gep1 = getelementptr inbounds float, ptr %1, i64 %2

%m = mul nsw nuw i8 %b, 2
%3 = sext i8 %m to i16
%4 = zext i16 %3 to i64
%gep2 = getelementptr inbounds float, ptr %1, i64 %4

load <2 x float>, ptr %gep1
load <2 x float>, ptr %gep2
ret void
}

;; This is the same as above, but this time the shl does not have the nsw flag.
;; the nneg cannot be kept on the zext.
define void @t4(i8 %v) {
; CHECK-LABEL: Function: t4
; CHECK: MayAlias: <2 x float>* %gep1, <2 x float>* %gep2
%a = or i8 %v, 1
%b = and i8 %v, 2

%1 = alloca [8 x float], align 4
%or1 = shl nuw i8 %a, 1
%2 = zext nneg i8 %or1 to i64
%gep1 = getelementptr inbounds float, ptr %1, i64 %2

%m = mul nsw nuw i8 %b, 2
%3 = sext i8 %m to i16
%4 = zext i16 %3 to i64
%gep2 = getelementptr inbounds float, ptr %1, i64 %4

load <2 x float>, ptr %gep1
load <2 x float>, ptr %gep2
ret void
}

;; Verify a zext nneg and a zext are understood as the same
define void @t5(ptr %p, i16 %i) {
; CHECK-LABEL: Function: t5
; CHECK: NoAlias: i32* %pi, i32* %pi.next
%i1 = zext nneg i16 %i to i32
%pi = getelementptr i32, ptr %p, i32 %i1

%i.next = add i16 %i, 1
%i.next2 = zext i16 %i.next to i32
%pi.next = getelementptr i32, ptr %p, i32 %i.next2

load i32, ptr %pi
load i32, ptr %pi.next
ret void
}

;; This is not very idiomatic, but still possible, verify the nneg is propagated
;; outward. and that no alias is correctly identified.
define void @t6(i8 %a) {
; CHECK-LABEL: Function: t6
; CHECK: NoAlias: float* %gep1, float* %gep2
%1 = alloca [8 x float], align 4
%a.add = add i8 %a, 1
%2 = zext nneg i8 %a.add to i16
%3 = sext i16 %2 to i32
%4 = zext i32 %3 to i64
%gep1 = getelementptr inbounds float, ptr %1, i64 %4

%5 = sext i8 %a to i64
%gep2 = getelementptr inbounds float, ptr %1, i64 %5

load float, ptr %gep1
load float, ptr %gep2
ret void
}

;; This is even less idiomatic, but still possible, verify the nneg is not
;; propagated inward. and that may alias is correctly identified.
define void @t7(i8 %a) {
; CHECK-LABEL: Function: t7
; CHECK: MayAlias: float* %gep1, float* %gep2
%1 = alloca [8 x float], align 4
%a.add = add i8 %a, 1
%2 = zext i8 %a.add to i16
%3 = sext i16 %2 to i32
%4 = zext nneg i32 %3 to i64
%gep1 = getelementptr inbounds float, ptr %1, i64 %4

%5 = sext i8 %a to i64
%gep2 = getelementptr inbounds float, ptr %1, i64 %5

load float, ptr %gep1
load float, ptr %gep2
ret void
}

;; Verify the nneg survives an implicit trunc of fewer bits then the zext.
define void @t8(i8 %a) {
; CHECK-LABEL: Function: t8
; CHECK: NoAlias: float* %gep1, float* %gep2
%1 = alloca [8 x float], align 4
%a.add = add i8 %a, 1
%2 = zext nneg i8 %a.add to i128
%gep1 = getelementptr inbounds float, ptr %1, i128 %2

%3 = sext i8 %a to i64
%gep2 = getelementptr inbounds float, ptr %1, i64 %3

load float, ptr %gep1
load float, ptr %gep2
ret void
}

;; Ensure that the nneg is never propagated past this trunc and that these
;; casted values are understood as non-equal.
define void @t9(i8 %a) {
; CHECK-LABEL: Function: t9
; CHECK: MayAlias: float* %gep1, float* %gep2
%1 = alloca [8 x float], align 4
%a.add = add i8 %a, 1
%2 = zext i8 %a.add to i16
%3 = trunc i16 %2 to i1
%4 = zext nneg i1 %3 to i64
%gep1 = getelementptr inbounds float, ptr %1, i64 %4

%5 = sext i8 %a to i64
%gep2 = getelementptr inbounds float, ptr %1, i64 %5

load float, ptr %gep1
load float, ptr %gep2
ret void
}
Loading