Skip to content

Commit 1e5334b

Browse files
kartcqMeinersbur
andauthored
[Polly] Data flow reduction detection to cover more cases (#84901)
The base concept is same as existing reduction algorithm where we get the list of candidate pairs <store,load>. But the existing algorithm works only if there is single binary operation between the load and store. Example sum += a[i]; This algorithm extends to work with more than single binary operation as well. It is implemented using data flow reduction detection on basic block level. We propagate the loads, the number of times the load is used(flows into instruction) and binary operation performed until we reach a store. Example sum += a[i] + b[i]; ``` sum(Ld) a[i](Ld) \ + / tmp b[i](Ld) \ + / sum(St) ``` In the above case the candidate pairs are formed by associating sum with all of its load inputs which are sum, a[i] and b[i]. Then check functions are used to filter a valid reduction pair ie {sum,sum}. --------- Co-authored-by: Michael Kruse <[email protected]>
1 parent 1ada235 commit 1e5334b

13 files changed

+651
-91
lines changed

polly/include/polly/ScopBuilder.h

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -663,19 +663,6 @@ class ScopBuilder final {
663663
/// nullptr if it cannot be hoisted at all.
664664
isl::set getNonHoistableCtx(MemoryAccess *Access, isl::union_map Writes);
665665

666-
/// Collect loads which might form a reduction chain with @p StoreMA.
667-
///
668-
/// Check if the stored value for @p StoreMA is a binary operator with one or
669-
/// two loads as operands. If the binary operand is commutative & associative,
670-
/// used only once (by @p StoreMA) and its load operands are also used only
671-
/// once, we have found a possible reduction chain. It starts at an operand
672-
/// load and includes the binary operator and @p StoreMA.
673-
///
674-
/// Note: We allow only one use to ensure the load and binary operator cannot
675-
/// escape this block or into any other store except @p StoreMA.
676-
void collectCandidateReductionLoads(MemoryAccess *StoreMA,
677-
SmallVectorImpl<MemoryAccess *> &Loads);
678-
679666
/// Build the access relation of all memory accesses of @p Stmt.
680667
void buildAccessRelations(ScopStmt &Stmt);
681668

polly/include/polly/ScopInfo.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -470,6 +470,8 @@ class MemoryAccess final {
470470
RT_BOR, ///< Bitwise Or
471471
RT_BXOR, ///< Bitwise XOr
472472
RT_BAND, ///< Bitwise And
473+
474+
RT_BOTTOM, ///< Pseudo type for the data flow analysis
473475
};
474476

475477
using SubscriptsTy = SmallVector<const SCEV *, 4>;
@@ -1139,6 +1141,7 @@ class ScopStmt final {
11391141
friend class ScopBuilder;
11401142

11411143
public:
1144+
using MemoryAccessVec = llvm::SmallVector<MemoryAccess *, 8>;
11421145
/// Create the ScopStmt from a BasicBlock.
11431146
ScopStmt(Scop &parent, BasicBlock &bb, StringRef Name, Loop *SurroundingLoop,
11441147
std::vector<Instruction *> Instructions);
@@ -1206,7 +1209,6 @@ class ScopStmt final {
12061209
/// The memory accesses of this statement.
12071210
///
12081211
/// The only side effects of a statement are its memory accesses.
1209-
using MemoryAccessVec = llvm::SmallVector<MemoryAccess *, 8>;
12101212
MemoryAccessVec MemAccs;
12111213

12121214
/// Mapping from instructions to (scalar) memory accesses.

polly/lib/Analysis/ScopBuilder.cpp

Lines changed: 199 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -2481,8 +2481,8 @@ void ScopBuilder::collectSurroundingLoops(ScopStmt &Stmt) {
24812481
}
24822482

24832483
/// Return the reduction type for a given binary operator.
2484-
static MemoryAccess::ReductionType getReductionType(const BinaryOperator *BinOp,
2485-
const Instruction *Load) {
2484+
static MemoryAccess::ReductionType
2485+
getReductionType(const BinaryOperator *BinOp) {
24862486
if (!BinOp)
24872487
return MemoryAccess::RT_NONE;
24882488
switch (BinOp->getOpcode()) {
@@ -2511,6 +2511,17 @@ static MemoryAccess::ReductionType getReductionType(const BinaryOperator *BinOp,
25112511
}
25122512
}
25132513

2514+
/// @brief Combine two reduction types
2515+
static MemoryAccess::ReductionType
2516+
combineReductionType(MemoryAccess::ReductionType RT0,
2517+
MemoryAccess::ReductionType RT1) {
2518+
if (RT0 == MemoryAccess::RT_BOTTOM)
2519+
return RT1;
2520+
if (RT0 == RT1)
2521+
return RT1;
2522+
return MemoryAccess::RT_NONE;
2523+
}
2524+
25142525
/// True if @p AllAccs intersects with @p MemAccs execpt @p LoadMA and @p
25152526
/// StoreMA
25162527
bool hasIntersectingAccesses(isl::set AllAccs, MemoryAccess *LoadMA,
@@ -2571,47 +2582,206 @@ bool checkCandidatePairAccesses(MemoryAccess *LoadMA, MemoryAccess *StoreMA,
25712582
AllAccsRel = AllAccsRel.intersect_domain(Domain);
25722583
isl::set AllAccs = AllAccsRel.range();
25732584
Valid = !hasIntersectingAccesses(AllAccs, LoadMA, StoreMA, Domain, MemAccs);
2574-
25752585
POLLY_DEBUG(dbgs() << " == The accessed memory is " << (Valid ? "not " : "")
25762586
<< "accessed by other instructions!\n");
25772587
}
2588+
25782589
return Valid;
25792590
}
25802591

25812592
void ScopBuilder::checkForReductions(ScopStmt &Stmt) {
2582-
SmallVector<MemoryAccess *, 2> Loads;
2583-
SmallVector<std::pair<MemoryAccess *, MemoryAccess *>, 4> Candidates;
2593+
// Perform a data flow analysis on the current scop statement to propagate the
2594+
// uses of loaded values. Then check and mark the memory accesses which are
2595+
// part of reduction like chains.
2596+
// During the data flow analysis we use the State variable to keep track of
2597+
// the used "load-instructions" for each instruction in the scop statement.
2598+
// This includes the LLVM-IR of the load and the "number of uses" (or the
2599+
// number of paths in the operand tree which end in this load).
2600+
using StatePairTy = std::pair<unsigned, MemoryAccess::ReductionType>;
2601+
using FlowInSetTy = MapVector<const LoadInst *, StatePairTy>;
2602+
using StateTy = MapVector<const Instruction *, FlowInSetTy>;
2603+
StateTy State;
2604+
2605+
// Invalid loads are loads which have uses we can't track properly in the
2606+
// state map. This includes loads which:
2607+
// o do not form a reduction when they flow into a memory location:
2608+
// (e.g., A[i] = B[i] * 3 and A[i] = A[i] * A[i] + A[i])
2609+
// o are used by a non binary operator or one which is not commutative
2610+
// and associative (e.g., A[i] = A[i] % 3)
2611+
// o might change the control flow (e.g., if (A[i]))
2612+
// o are used in indirect memory accesses (e.g., A[B[i]])
2613+
// o are used outside the current scop statement
2614+
SmallPtrSet<const Instruction *, 8> InvalidLoads;
2615+
SmallVector<BasicBlock *, 8> ScopBlocks;
2616+
BasicBlock *BB = Stmt.getBasicBlock();
2617+
if (BB)
2618+
ScopBlocks.push_back(BB);
2619+
else
2620+
for (BasicBlock *Block : Stmt.getRegion()->blocks())
2621+
ScopBlocks.push_back(Block);
2622+
// Run the data flow analysis for all values in the scop statement
2623+
for (BasicBlock *Block : ScopBlocks) {
2624+
for (Instruction &Inst : *Block) {
2625+
if ((Stmt.getParent())->getStmtFor(&Inst) != &Stmt)
2626+
continue;
2627+
bool UsedOutsideStmt = any_of(Inst.users(), [&Stmt](User *U) {
2628+
return (Stmt.getParent())->getStmtFor(cast<Instruction>(U)) != &Stmt;
2629+
});
2630+
// Treat loads and stores special
2631+
if (auto *Load = dyn_cast<LoadInst>(&Inst)) {
2632+
// Invalidate all loads used which feed into the address of this load.
2633+
if (auto *Ptr = dyn_cast<Instruction>(Load->getPointerOperand())) {
2634+
const auto &It = State.find(Ptr);
2635+
if (It != State.end())
2636+
for (const auto &FlowInSetElem : It->second)
2637+
InvalidLoads.insert(FlowInSetElem.first);
2638+
}
25842639

2585-
// First collect candidate load-store reduction chains by iterating over all
2586-
// stores and collecting possible reduction loads.
2587-
for (MemoryAccess *StoreMA : Stmt) {
2588-
if (StoreMA->isRead())
2589-
continue;
2640+
// If this load is used outside this stmt, invalidate it.
2641+
if (UsedOutsideStmt)
2642+
InvalidLoads.insert(Load);
2643+
2644+
// And indicate that this load uses itself once but without specifying
2645+
// any reduction operator.
2646+
State[Load].insert(
2647+
std::make_pair(Load, std::make_pair(1, MemoryAccess::RT_BOTTOM)));
2648+
continue;
2649+
}
2650+
2651+
if (auto *Store = dyn_cast<StoreInst>(&Inst)) {
2652+
// Invalidate all loads which feed into the address of this store.
2653+
if (const Instruction *Ptr =
2654+
dyn_cast<Instruction>(Store->getPointerOperand())) {
2655+
const auto &It = State.find(Ptr);
2656+
if (It != State.end())
2657+
for (const auto &FlowInSetElem : It->second)
2658+
InvalidLoads.insert(FlowInSetElem.first);
2659+
}
2660+
2661+
// Propagate the uses of the value operand to the store
2662+
if (auto *ValueInst = dyn_cast<Instruction>(Store->getValueOperand()))
2663+
State.insert(std::make_pair(Store, State[ValueInst]));
2664+
continue;
2665+
}
2666+
2667+
// Non load and store instructions are either binary operators or they
2668+
// will invalidate all used loads.
2669+
auto *BinOp = dyn_cast<BinaryOperator>(&Inst);
2670+
MemoryAccess::ReductionType CurRedType = getReductionType(BinOp);
2671+
POLLY_DEBUG(dbgs() << "CurInst: " << Inst << " RT: " << CurRedType
2672+
<< "\n");
2673+
2674+
// Iterate over all operands and propagate their input loads to
2675+
// instruction.
2676+
FlowInSetTy &InstInFlowSet = State[&Inst];
2677+
for (Use &Op : Inst.operands()) {
2678+
auto *OpInst = dyn_cast<Instruction>(Op);
2679+
if (!OpInst)
2680+
continue;
2681+
2682+
POLLY_DEBUG(dbgs().indent(4) << "Op Inst: " << *OpInst << "\n");
2683+
const StateTy::iterator &OpInFlowSetIt = State.find(OpInst);
2684+
if (OpInFlowSetIt == State.end())
2685+
continue;
2686+
2687+
// Iterate over all the input loads of the operand and combine them
2688+
// with the input loads of current instruction.
2689+
FlowInSetTy &OpInFlowSet = OpInFlowSetIt->second;
2690+
for (auto &OpInFlowPair : OpInFlowSet) {
2691+
unsigned OpFlowIn = OpInFlowPair.second.first;
2692+
unsigned InstFlowIn = InstInFlowSet[OpInFlowPair.first].first;
2693+
2694+
MemoryAccess::ReductionType OpRedType = OpInFlowPair.second.second;
2695+
MemoryAccess::ReductionType InstRedType =
2696+
InstInFlowSet[OpInFlowPair.first].second;
2697+
2698+
MemoryAccess::ReductionType NewRedType =
2699+
combineReductionType(OpRedType, CurRedType);
2700+
if (InstFlowIn)
2701+
NewRedType = combineReductionType(NewRedType, InstRedType);
2702+
2703+
POLLY_DEBUG(dbgs().indent(8) << "OpRedType: " << OpRedType << "\n");
2704+
POLLY_DEBUG(dbgs().indent(8) << "NewRedType: " << NewRedType << "\n");
2705+
InstInFlowSet[OpInFlowPair.first] =
2706+
std::make_pair(OpFlowIn + InstFlowIn, NewRedType);
2707+
}
2708+
}
25902709

2591-
Loads.clear();
2592-
collectCandidateReductionLoads(StoreMA, Loads);
2593-
for (MemoryAccess *LoadMA : Loads)
2594-
Candidates.push_back(std::make_pair(LoadMA, StoreMA));
2710+
// If this operation is used outside the stmt, invalidate all the loads
2711+
// which feed into it.
2712+
if (UsedOutsideStmt)
2713+
for (const auto &FlowInSetElem : InstInFlowSet)
2714+
InvalidLoads.insert(FlowInSetElem.first);
2715+
}
25952716
}
25962717

2597-
// Then check each possible candidate pair.
2598-
for (const auto &CandidatePair : Candidates) {
2599-
MemoryAccess *LoadMA = CandidatePair.first;
2600-
MemoryAccess *StoreMA = CandidatePair.second;
2601-
bool Valid = checkCandidatePairAccesses(LoadMA, StoreMA, Stmt.getDomain(),
2602-
Stmt.MemAccs);
2603-
if (!Valid)
2718+
// All used loads are propagated through the whole basic block; now try to
2719+
// find valid reduction-like candidate pairs. These load-store pairs fulfill
2720+
// all reduction like properties with regards to only this load-store chain.
2721+
// We later have to check if the loaded value was invalidated by an
2722+
// instruction not in that chain.
2723+
using MemAccPair = std::pair<MemoryAccess *, MemoryAccess *>;
2724+
DenseMap<MemAccPair, MemoryAccess::ReductionType> ValidCandidates;
2725+
DominatorTree *DT = Stmt.getParent()->getDT();
2726+
2727+
// Iterate over all write memory accesses and check the loads flowing into
2728+
// it for reduction candidate pairs.
2729+
for (MemoryAccess *WriteMA : Stmt.MemAccs) {
2730+
if (WriteMA->isRead())
2731+
continue;
2732+
StoreInst *St = dyn_cast<StoreInst>(WriteMA->getAccessInstruction());
2733+
if (!St)
26042734
continue;
2735+
assert(!St->isVolatile());
2736+
2737+
FlowInSetTy &MaInFlowSet = State[WriteMA->getAccessInstruction()];
2738+
for (auto &MaInFlowSetElem : MaInFlowSet) {
2739+
MemoryAccess *ReadMA = &Stmt.getArrayAccessFor(MaInFlowSetElem.first);
2740+
assert(ReadMA && "Couldn't find memory access for incoming load!");
26052741

2606-
const LoadInst *Load =
2607-
dyn_cast<const LoadInst>(CandidatePair.first->getAccessInstruction());
2608-
MemoryAccess::ReductionType RT =
2609-
getReductionType(dyn_cast<BinaryOperator>(Load->user_back()), Load);
2742+
POLLY_DEBUG(dbgs() << "'" << *ReadMA->getAccessInstruction()
2743+
<< "'\n\tflows into\n'"
2744+
<< *WriteMA->getAccessInstruction() << "'\n\t #"
2745+
<< MaInFlowSetElem.second.first << " times & RT: "
2746+
<< MaInFlowSetElem.second.second << "\n");
26102747

2611-
// If no overlapping access was found we mark the load and store as
2612-
// reduction like.
2613-
LoadMA->markAsReductionLike(RT);
2614-
StoreMA->markAsReductionLike(RT);
2748+
MemoryAccess::ReductionType RT = MaInFlowSetElem.second.second;
2749+
unsigned NumAllowableInFlow = 1;
2750+
2751+
// We allow the load to flow in exactly once for binary reductions
2752+
bool Valid = (MaInFlowSetElem.second.first == NumAllowableInFlow);
2753+
2754+
// Check if we saw a valid chain of binary operators.
2755+
Valid = Valid && RT != MemoryAccess::RT_BOTTOM;
2756+
Valid = Valid && RT != MemoryAccess::RT_NONE;
2757+
2758+
// Then check if the memory accesses allow a reduction.
2759+
Valid = Valid && checkCandidatePairAccesses(
2760+
ReadMA, WriteMA, Stmt.getDomain(), Stmt.MemAccs);
2761+
2762+
// Finally, mark the pair as a candidate or the load as a invalid one.
2763+
if (Valid)
2764+
ValidCandidates[std::make_pair(ReadMA, WriteMA)] = RT;
2765+
else
2766+
InvalidLoads.insert(ReadMA->getAccessInstruction());
2767+
}
2768+
}
2769+
2770+
// In the last step mark the memory accesses of candidate pairs as reduction
2771+
// like if the load wasn't marked invalid in the previous step.
2772+
for (auto &CandidatePair : ValidCandidates) {
2773+
MemoryAccess *LoadMA = CandidatePair.first.first;
2774+
if (InvalidLoads.count(LoadMA->getAccessInstruction()))
2775+
continue;
2776+
POLLY_DEBUG(
2777+
dbgs() << " Load :: "
2778+
<< *((CandidatePair.first.first)->getAccessInstruction())
2779+
<< "\n Store :: "
2780+
<< *((CandidatePair.first.second)->getAccessInstruction())
2781+
<< "\n are marked as reduction like\n");
2782+
MemoryAccess::ReductionType RT = CandidatePair.second;
2783+
CandidatePair.first.first->markAsReductionLike(RT);
2784+
CandidatePair.first.second->markAsReductionLike(RT);
26152785
}
26162786
}
26172787

@@ -2965,52 +3135,6 @@ void ScopBuilder::addInvariantLoads(ScopStmt &Stmt,
29653135
}
29663136
}
29673137

2968-
void ScopBuilder::collectCandidateReductionLoads(
2969-
MemoryAccess *StoreMA, SmallVectorImpl<MemoryAccess *> &Loads) {
2970-
ScopStmt *Stmt = StoreMA->getStatement();
2971-
2972-
auto *Store = dyn_cast<StoreInst>(StoreMA->getAccessInstruction());
2973-
if (!Store)
2974-
return;
2975-
2976-
// Skip if there is not one binary operator between the load and the store
2977-
auto *BinOp = dyn_cast<BinaryOperator>(Store->getValueOperand());
2978-
if (!BinOp)
2979-
return;
2980-
2981-
// Skip if the binary operators has multiple uses
2982-
if (BinOp->getNumUses() != 1)
2983-
return;
2984-
2985-
// Skip if the opcode of the binary operator is not commutative/associative
2986-
if (!BinOp->isCommutative() || !BinOp->isAssociative())
2987-
return;
2988-
2989-
// Skip if the binary operator is outside the current SCoP
2990-
if (BinOp->getParent() != Store->getParent())
2991-
return;
2992-
2993-
// Skip if it is a multiplicative reduction and we disabled them
2994-
if (DisableMultiplicativeReductions &&
2995-
(BinOp->getOpcode() == Instruction::Mul ||
2996-
BinOp->getOpcode() == Instruction::FMul))
2997-
return;
2998-
2999-
// Check the binary operator operands for a candidate load
3000-
auto *PossibleLoad0 = dyn_cast<LoadInst>(BinOp->getOperand(0));
3001-
auto *PossibleLoad1 = dyn_cast<LoadInst>(BinOp->getOperand(1));
3002-
if (!PossibleLoad0 && !PossibleLoad1)
3003-
return;
3004-
3005-
// A load is only a candidate if it cannot escape (thus has only this use)
3006-
if (PossibleLoad0 && PossibleLoad0->getNumUses() == 1)
3007-
if (PossibleLoad0->getParent() == Store->getParent())
3008-
Loads.push_back(&Stmt->getArrayAccessFor(PossibleLoad0));
3009-
if (PossibleLoad1 && PossibleLoad1->getNumUses() == 1)
3010-
if (PossibleLoad1->getParent() == Store->getParent())
3011-
Loads.push_back(&Stmt->getArrayAccessFor(PossibleLoad1));
3012-
}
3013-
30143138
/// Find the canonical scop array info object for a set of invariant load
30153139
/// hoisted loads. The canonical array is the one that corresponds to the
30163140
/// first load in the list of accesses which is used as base pointer of a

polly/lib/Analysis/ScopInfo.cpp

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -533,6 +533,9 @@ MemoryAccess::getReductionOperatorStr(MemoryAccess::ReductionType RT) {
533533
case MemoryAccess::RT_NONE:
534534
llvm_unreachable("Requested a reduction operator string for a memory "
535535
"access which isn't a reduction");
536+
case MemoryAccess::RT_BOTTOM:
537+
llvm_unreachable("Requested a reduction operator string for a internal "
538+
"reduction type!");
536539
case MemoryAccess::RT_ADD:
537540
return "+";
538541
case MemoryAccess::RT_MUL:
@@ -915,10 +918,15 @@ isl::id MemoryAccess::getId() const { return Id; }
915918

916919
raw_ostream &polly::operator<<(raw_ostream &OS,
917920
MemoryAccess::ReductionType RT) {
918-
if (RT == MemoryAccess::RT_NONE)
921+
switch (RT) {
922+
case MemoryAccess::RT_NONE:
923+
case MemoryAccess::RT_BOTTOM:
919924
OS << "NONE";
920-
else
925+
break;
926+
default:
921927
OS << MemoryAccess::getReductionOperatorStr(RT);
928+
break;
929+
}
922930
return OS;
923931
}
924932

0 commit comments

Comments
 (0)