Skip to content

Commit c11ea44

Browse files
[ValueTracking] Add matchSimpleBinaryIntrinsicRecurrence helper
Similarly to what is already done to match simple recurrence relations, attempt to match value-accumulating recurrences of the kind:
```
%umax.acc = phi i8 [ %umax, %backedge ], [ %a, %entry ]
%umax = call i8 @llvm.umax.i8(i8 %umax.acc, i8 %b)
```
This is preliminary work to let InstCombine avoid folding such recurrences, so that simple loop-invariant computations may get hoisted. It is also a minor opportunity to factor out shared code.
1 parent ac7e391 commit c11ea44

File tree

3 files changed

+111
-30
lines changed

3 files changed

+111
-30
lines changed

llvm/include/llvm/Analysis/ValueTracking.h

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
#include "llvm/IR/FMF.h"
2222
#include "llvm/IR/InstrTypes.h"
2323
#include "llvm/IR/Instructions.h"
24+
#include "llvm/IR/IntrinsicInst.h"
2425
#include "llvm/IR/Intrinsics.h"
2526
#include "llvm/Support/Compiler.h"
2627
#include <cassert>
@@ -965,6 +966,21 @@ LLVM_ABI bool matchSimpleRecurrence(const PHINode *P, BinaryOperator *&BO,
965966
LLVM_ABI bool matchSimpleRecurrence(const BinaryOperator *I, PHINode *&P,
966967
Value *&Start, Value *&Step);
967968

969+
/// Attempt to match a simple value-accumulating recurrence of the form:
970+
/// %llvm.intrinsic.acc = phi Ty [%Init, %Entry], [%llvm.intrinsic, %backedge]
971+
/// %llvm.intrinsic = call Ty @llvm.intrinsic(%OtherOp, %llvm.intrinsic.acc)
972+
/// OR
973+
/// %llvm.intrinsic.acc = phi Ty [%Init, %Entry], [%llvm.intrinsic, %backedge]
974+
/// %llvm.intrinsic = call Ty @llvm.intrinsic(%llvm.intrinsic.acc, %OtherOp)
975+
///
976+
/// The recurrence relation is of kind:
977+
/// X_0 = %Init (initial value),
978+
/// X_i = call @llvm.binary.intrinsic(X_i-1, %OtherOp)
979+
/// Where %OtherOp is not required to be loop-invariant.
980+
LLVM_ABI bool matchSimpleBinaryIntrinsicRecurrence(const IntrinsicInst *I,
981+
PHINode *&P, Value *&Init,
982+
Value *&OtherOp);
983+
968984
/// Return true if RHS is known to be implied true by LHS. Return false if
969985
/// RHS is known to be implied false by LHS. Otherwise, return std::nullopt if
970986
/// no implication can be made. A & B must be i1 (boolean) values or a vector of

llvm/lib/Analysis/ValueTracking.cpp

Lines changed: 43 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -9070,46 +9070,43 @@ llvm::canConvertToMinOrMaxIntrinsic(ArrayRef<Value *> VL) {
90709070
return {Intrinsic::not_intrinsic, false};
90719071
}
90729072

9073-
bool llvm::matchSimpleRecurrence(const PHINode *P, BinaryOperator *&BO,
9074-
Value *&Start, Value *&Step) {
9073+
template <typename InstTy>
9074+
static bool matchTwoInputRecurrence(const PHINode *PN, InstTy *&Inst,
9075+
Value *&Init, Value *&OtherOp) {
90759076
// Handle the case of a simple two-predecessor recurrence PHI.
90769077
// There's a lot more that could theoretically be done here, but
90779078
// this is sufficient to catch some interesting cases.
90789079
// TODO: Expand list -- gep, uadd.sat etc.
9079-
if (P->getNumIncomingValues() != 2)
9080+
if (PN->getNumIncomingValues() != 2)
90809081
return false;
90819082

9082-
for (unsigned i = 0; i != 2; ++i) {
9083-
Value *L = P->getIncomingValue(i);
9084-
Value *R = P->getIncomingValue(!i);
9085-
auto *LU = dyn_cast<BinaryOperator>(L);
9086-
if (!LU)
9087-
continue;
9088-
Value *LL = LU->getOperand(0);
9089-
Value *LR = LU->getOperand(1);
9090-
9091-
// Find a recurrence.
9092-
if (LL == P)
9093-
L = LR;
9094-
else if (LR == P)
9095-
L = LL;
9096-
else
9097-
continue; // Check for recurrence with L and R flipped.
9098-
9099-
// We have matched a recurrence of the form:
9100-
// %iv = [R, %entry], [%iv.next, %backedge]
9101-
// %iv.next = binop %iv, L
9102-
// OR
9103-
// %iv = [R, %entry], [%iv.next, %backedge]
9104-
// %iv.next = binop L, %iv
9105-
BO = LU;
9106-
Start = R;
9107-
Step = L;
9108-
return true;
9083+
for (unsigned I = 0; I != 2; ++I) {
9084+
if (auto *Operation = dyn_cast<InstTy>(PN->getIncomingValue(I))) {
9085+
Value *LHS = Operation->getOperand(0);
9086+
Value *RHS = Operation->getOperand(1);
9087+
if (LHS != PN && RHS != PN)
9088+
continue;
9089+
9090+
Inst = Operation;
9091+
Init = PN->getIncomingValue(!I);
9092+
OtherOp = (LHS == PN) ? RHS : LHS;
9093+
return true;
9094+
}
91099095
}
91109096
return false;
91119097
}
91129098

9099+
bool llvm::matchSimpleRecurrence(const PHINode *P, BinaryOperator *&BO,
9100+
Value *&Start, Value *&Step) {
9101+
// We try to match a recurrence of the form:
9102+
// %iv = [Start, %entry], [%iv.next, %backedge]
9103+
// %iv.next = binop %iv, Step
9104+
// Or:
9105+
// %iv = [Start, %entry], [%iv.next, %backedge]
9106+
// %iv.next = binop Step, %iv
9107+
return matchTwoInputRecurrence(P, BO, Start, Step);
9108+
}
9109+
91139110
bool llvm::matchSimpleRecurrence(const BinaryOperator *I, PHINode *&P,
91149111
Value *&Start, Value *&Step) {
91159112
BinaryOperator *BO = nullptr;
@@ -9119,6 +9116,22 @@ bool llvm::matchSimpleRecurrence(const BinaryOperator *I, PHINode *&P,
91199116
return P && matchSimpleRecurrence(P, BO, Start, Step) && BO == I;
91209117
}
91219118

9119+
bool llvm::matchSimpleBinaryIntrinsicRecurrence(const IntrinsicInst *I,
9120+
PHINode *&P, Value *&Init,
9121+
Value *&OtherOp) {
9122+
// Binary intrinsics only supported for now.
9123+
if (I->arg_size() != 2 || I->getType() != I->getArgOperand(0)->getType() ||
9124+
I->getType() != I->getArgOperand(1)->getType())
9125+
return false;
9126+
9127+
IntrinsicInst *II = nullptr;
9128+
P = dyn_cast<PHINode>(I->getArgOperand(0));
9129+
if (!P)
9130+
P = dyn_cast<PHINode>(I->getArgOperand(1));
9131+
9132+
return P && matchTwoInputRecurrence(P, II, Init, OtherOp) && II == I;
9133+
}
9134+
91229135
/// Return true if "icmp Pred LHS RHS" is always true.
91239136
static bool isTruePredicate(CmpInst::Predicate Pred, const Value *LHS,
91249137
const Value *RHS) {

llvm/unittests/Analysis/ValueTrackingTest.cpp

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1257,6 +1257,58 @@ TEST_F(ValueTrackingTest, computePtrAlignment) {
12571257
EXPECT_EQ(getKnownAlignment(A, DL, CxtI3, &AC, &DT), Align(16));
12581258
}
12591259

1260+
// Positive case: %umax.acc / %umax form a umax-accumulating recurrence whose
// initial value is %a and whose other operand is %b.
TEST_F(ValueTrackingTest, MatchBinaryIntrinsicRecurrenceUMax) {
  auto M = parseModule(R"(
define i8 @test(i8 %a, i8 %b) {
entry:
br label %loop
loop:
%iv = phi i8 [ %iv.next, %loop ], [ 0, %entry ]
%umax.acc = phi i8 [ %umax, %loop ], [ %a, %entry ]
%umax = call i8 @llvm.umax.i8(i8 %umax.acc, i8 %b)
%iv.next = add nuw i8 %iv, 1
%cmp = icmp ult i8 %iv.next, 10
br i1 %cmp, label %loop, label %exit
exit:
ret i8 %umax
}
)");

  auto *F = M->getFunction("test");
  auto *II = &cast<IntrinsicInst>(findInstructionByName(F, "umax"));
  auto *UMaxAcc = &cast<PHINode>(findInstructionByName(F, "umax.acc"));
  // EXPECT_TRUE is non-fatal, so the comparisons below still run when the
  // match fails. Initialize the out-parameters so they never compare
  // indeterminate pointers in that case.
  PHINode *PN = nullptr;
  Value *Init = nullptr;
  Value *OtherOp = nullptr;
  EXPECT_TRUE(matchSimpleBinaryIntrinsicRecurrence(II, PN, Init, OtherOp));
  EXPECT_EQ(UMaxAcc, PN);
  EXPECT_EQ(F->getArg(0), Init);
  EXPECT_EQ(F->getArg(1), OtherOp);
}
1287+
1288+
// Negative case: llvm.fshr takes three arguments, so it must not be matched
// as a binary intrinsic recurrence.
TEST_F(ValueTrackingTest, MatchBinaryIntrinsicRecurrenceNegativeFSHR) {
  auto M = parseModule(R"(
define i8 @test(i8 %a, i8 %b, i8 %c) {
entry:
br label %loop
loop:
%iv = phi i8 [ %iv.next, %loop ], [ 0, %entry ]
%fshr.acc = phi i8 [ %fshr, %loop ], [ %a, %entry ]
%fshr = call i8 @llvm.fshr.i8(i8 %fshr.acc, i8 %b, i8 %c)
%iv.next = add nuw i8 %iv, 1
%cmp = icmp ult i8 %iv.next, 10
br i1 %cmp, label %loop, label %exit
exit:
ret i8 %fshr
}
)");

  auto *Fn = M->getFunction("test");
  auto *FunnelShift = cast<IntrinsicInst>(&findInstructionByName(Fn, "fshr"));
  PHINode *MatchedPhi;
  Value *MatchedInit;
  Value *MatchedOther;
  const bool Matched = matchSimpleBinaryIntrinsicRecurrence(
      FunnelShift, MatchedPhi, MatchedInit, MatchedOther);
  EXPECT_FALSE(Matched);
}
1311+
12601312
TEST_F(ComputeKnownBitsTest, ComputeKnownBits) {
12611313
parseAssembly(
12621314
"define i32 @test(i32 %a, i32 %b) {\n"

0 commit comments

Comments
 (0)