Skip to content

[Polly] Data flow reduction detection to cover more cases #84901

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Jul 30, 2024
13 changes: 0 additions & 13 deletions polly/include/polly/ScopBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -663,19 +663,6 @@ class ScopBuilder final {
/// nullptr if it cannot be hoisted at all.
isl::set getNonHoistableCtx(MemoryAccess *Access, isl::union_map Writes);

/// Collect loads which might form a reduction chain with @p StoreMA.
///
/// Check if the stored value for @p StoreMA is a binary operator with one or
/// two loads as operands. If the binary operator is commutative & associative,
/// used only once (by @p StoreMA) and its load operands are also used only
/// once, we have found a possible reduction chain. It starts at an operand
/// load and includes the binary operator and @p StoreMA.
///
/// Note: We allow only one use to ensure the load and binary operator cannot
/// escape this block or into any other store except @p StoreMA.
void collectCandidateReductionLoads(MemoryAccess *StoreMA,
SmallVectorImpl<MemoryAccess *> &Loads);

/// Build the access relation of all memory accesses of @p Stmt.
void buildAccessRelations(ScopStmt &Stmt);

Expand Down
4 changes: 3 additions & 1 deletion polly/include/polly/ScopInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -470,6 +470,8 @@ class MemoryAccess final {
RT_BOR, ///< Bitwise Or
RT_BXOR, ///< Bitwise XOr
RT_BAND, ///< Bitwise And

RT_BOTTOM, ///< Pseudo type for the data flow analysis
};

using SubscriptsTy = SmallVector<const SCEV *, 4>;
Expand Down Expand Up @@ -1139,6 +1141,7 @@ class ScopStmt final {
friend class ScopBuilder;

public:
using MemoryAccessVec = llvm::SmallVector<MemoryAccess *, 8>;
/// Create the ScopStmt from a BasicBlock.
ScopStmt(Scop &parent, BasicBlock &bb, StringRef Name, Loop *SurroundingLoop,
std::vector<Instruction *> Instructions);
Expand Down Expand Up @@ -1206,7 +1209,6 @@ class ScopStmt final {
/// The memory accesses of this statement.
///
/// The only side effects of a statement are its memory accesses.
using MemoryAccessVec = llvm::SmallVector<MemoryAccess *, 8>;
MemoryAccessVec MemAccs;

/// Mapping from instructions to (scalar) memory accesses.
Expand Down
274 changes: 199 additions & 75 deletions polly/lib/Analysis/ScopBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2481,8 +2481,8 @@ void ScopBuilder::collectSurroundingLoops(ScopStmt &Stmt) {
}

/// Return the reduction type for a given binary operator.
static MemoryAccess::ReductionType getReductionType(const BinaryOperator *BinOp,
const Instruction *Load) {
static MemoryAccess::ReductionType
getReductionType(const BinaryOperator *BinOp) {
if (!BinOp)
return MemoryAccess::RT_NONE;
switch (BinOp->getOpcode()) {
Expand Down Expand Up @@ -2511,6 +2511,17 @@ static MemoryAccess::ReductionType getReductionType(const BinaryOperator *BinOp,
}
}

/// @brief Merge two reduction types into a single one.
///
/// If @p RT0 is RT_BOTTOM, or if both types are equal, the result is @p RT1.
/// Any other combination yields RT_NONE. Note that the check for RT_BOTTOM is
/// only applied to @p RT0, so the operation is not symmetric.
static MemoryAccess::ReductionType
combineReductionType(MemoryAccess::ReductionType RT0,
                     MemoryAccess::ReductionType RT1) {
  const bool ResultIsRT1 = RT0 == MemoryAccess::RT_BOTTOM || RT0 == RT1;
  return ResultIsRT1 ? RT1 : MemoryAccess::RT_NONE;
}

/// True if @p AllAccs intersects with @p MemAccs except @p LoadMA and @p
/// StoreMA
bool hasIntersectingAccesses(isl::set AllAccs, MemoryAccess *LoadMA,
Expand Down Expand Up @@ -2571,47 +2582,206 @@ bool checkCandidatePairAccesses(MemoryAccess *LoadMA, MemoryAccess *StoreMA,
AllAccsRel = AllAccsRel.intersect_domain(Domain);
isl::set AllAccs = AllAccsRel.range();
Valid = !hasIntersectingAccesses(AllAccs, LoadMA, StoreMA, Domain, MemAccs);

POLLY_DEBUG(dbgs() << " == The accessed memory is " << (Valid ? "not " : "")
<< "accessed by other instructions!\n");
}

return Valid;
}

void ScopBuilder::checkForReductions(ScopStmt &Stmt) {
SmallVector<MemoryAccess *, 2> Loads;
SmallVector<std::pair<MemoryAccess *, MemoryAccess *>, 4> Candidates;
// Perform a data flow analysis on the current scop statement to propagate the
// uses of loaded values. Then check and mark the memory accesses which are
// part of reduction like chains.
// During the data flow analysis we use the State variable to keep track of
// the used "load-instructions" for each instruction in the scop statement.
// This includes the LLVM-IR of the load and the "number of uses" (or the
// number of paths in the operand tree which end in this load).
using StatePairTy = std::pair<unsigned, MemoryAccess::ReductionType>;
using FlowInSetTy = MapVector<const LoadInst *, StatePairTy>;
using StateTy = MapVector<const Instruction *, FlowInSetTy>;
StateTy State;

// Invalid loads are loads which have uses we can't track properly in the
// state map. This includes loads which:
// o do not form a reduction when they flow into a memory location:
// (e.g., A[i] = B[i] * 3 and A[i] = A[i] * A[i] + A[i])
// o are used by a non binary operator or one which is not commutative
// and associative (e.g., A[i] = A[i] % 3)
// o might change the control flow (e.g., if (A[i]))
// o are used in indirect memory accesses (e.g., A[B[i]])
// o are used outside the current scop statement
SmallPtrSet<const Instruction *, 8> InvalidLoads;
SmallVector<BasicBlock *, 8> ScopBlocks;
BasicBlock *BB = Stmt.getBasicBlock();
if (BB)
ScopBlocks.push_back(BB);
else
for (BasicBlock *Block : Stmt.getRegion()->blocks())
ScopBlocks.push_back(Block);
// Run the data flow analysis for all values in the scop statement
for (BasicBlock *Block : ScopBlocks) {
for (Instruction &Inst : *Block) {
if ((Stmt.getParent())->getStmtFor(&Inst) != &Stmt)
continue;
bool UsedOutsideStmt = any_of(Inst.users(), [&Stmt](User *U) {
return (Stmt.getParent())->getStmtFor(cast<Instruction>(U)) != &Stmt;
});
// Treat loads and stores specially.
if (auto *Load = dyn_cast<LoadInst>(&Inst)) {
// Invalidate all loads which feed into the address of this load.
if (auto *Ptr = dyn_cast<Instruction>(Load->getPointerOperand())) {
const auto &It = State.find(Ptr);
if (It != State.end())
for (const auto &FlowInSetElem : It->second)
InvalidLoads.insert(FlowInSetElem.first);
}

// First collect candidate load-store reduction chains by iterating over all
// stores and collecting possible reduction loads.
for (MemoryAccess *StoreMA : Stmt) {
if (StoreMA->isRead())
continue;
// If this load is used outside this stmt, invalidate it.
if (UsedOutsideStmt)
InvalidLoads.insert(Load);

// And indicate that this load uses itself once but without specifying
// any reduction operator.
State[Load].insert(
std::make_pair(Load, std::make_pair(1, MemoryAccess::RT_BOTTOM)));
continue;
}

if (auto *Store = dyn_cast<StoreInst>(&Inst)) {
// Invalidate all loads which feed into the address of this store.
if (const Instruction *Ptr =
dyn_cast<Instruction>(Store->getPointerOperand())) {
const auto &It = State.find(Ptr);
if (It != State.end())
for (const auto &FlowInSetElem : It->second)
InvalidLoads.insert(FlowInSetElem.first);
}

// Propagate the uses of the value operand to the store
if (auto *ValueInst = dyn_cast<Instruction>(Store->getValueOperand()))
State.insert(std::make_pair(Store, State[ValueInst]));
continue;
}

// Non load and store instructions are either binary operators or they
// will invalidate all used loads.
auto *BinOp = dyn_cast<BinaryOperator>(&Inst);
MemoryAccess::ReductionType CurRedType = getReductionType(BinOp);
POLLY_DEBUG(dbgs() << "CurInst: " << Inst << " RT: " << CurRedType
<< "\n");

// Iterate over all operands and propagate their input loads to this
// instruction.
FlowInSetTy &InstInFlowSet = State[&Inst];
for (Use &Op : Inst.operands()) {
auto *OpInst = dyn_cast<Instruction>(Op);
if (!OpInst)
continue;

POLLY_DEBUG(dbgs().indent(4) << "Op Inst: " << *OpInst << "\n");
const StateTy::iterator &OpInFlowSetIt = State.find(OpInst);
if (OpInFlowSetIt == State.end())
continue;

// Iterate over all the input loads of the operand and combine them
// with the input loads of current instruction.
FlowInSetTy &OpInFlowSet = OpInFlowSetIt->second;
for (auto &OpInFlowPair : OpInFlowSet) {
unsigned OpFlowIn = OpInFlowPair.second.first;
unsigned InstFlowIn = InstInFlowSet[OpInFlowPair.first].first;

MemoryAccess::ReductionType OpRedType = OpInFlowPair.second.second;
MemoryAccess::ReductionType InstRedType =
InstInFlowSet[OpInFlowPair.first].second;

MemoryAccess::ReductionType NewRedType =
combineReductionType(OpRedType, CurRedType);
if (InstFlowIn)
NewRedType = combineReductionType(NewRedType, InstRedType);

POLLY_DEBUG(dbgs().indent(8) << "OpRedType: " << OpRedType << "\n");
POLLY_DEBUG(dbgs().indent(8) << "NewRedType: " << NewRedType << "\n");
InstInFlowSet[OpInFlowPair.first] =
std::make_pair(OpFlowIn + InstFlowIn, NewRedType);
}
}

Loads.clear();
collectCandidateReductionLoads(StoreMA, Loads);
for (MemoryAccess *LoadMA : Loads)
Candidates.push_back(std::make_pair(LoadMA, StoreMA));
// If this operation is used outside the stmt, invalidate all the loads
// which feed into it.
if (UsedOutsideStmt)
for (const auto &FlowInSetElem : InstInFlowSet)
InvalidLoads.insert(FlowInSetElem.first);
}
}

// Then check each possible candidate pair.
for (const auto &CandidatePair : Candidates) {
MemoryAccess *LoadMA = CandidatePair.first;
MemoryAccess *StoreMA = CandidatePair.second;
bool Valid = checkCandidatePairAccesses(LoadMA, StoreMA, Stmt.getDomain(),
Stmt.MemAccs);
if (!Valid)
// All used loads are propagated through the whole basic block; now try to
// find valid reduction-like candidate pairs. These load-store pairs fulfill
// all reduction like properties with regards to only this load-store chain.
// We later have to check if the loaded value was invalidated by an
// instruction not in that chain.
using MemAccPair = std::pair<MemoryAccess *, MemoryAccess *>;
DenseMap<MemAccPair, MemoryAccess::ReductionType> ValidCandidates;
DominatorTree *DT = Stmt.getParent()->getDT();

// Iterate over all write memory accesses and check the loads flowing into
// it for reduction candidate pairs.
for (MemoryAccess *WriteMA : Stmt.MemAccs) {
if (WriteMA->isRead())
continue;
StoreInst *St = dyn_cast<StoreInst>(WriteMA->getAccessInstruction());
if (!St)
continue;
assert(!St->isVolatile());

FlowInSetTy &MaInFlowSet = State[WriteMA->getAccessInstruction()];
for (auto &MaInFlowSetElem : MaInFlowSet) {
MemoryAccess *ReadMA = &Stmt.getArrayAccessFor(MaInFlowSetElem.first);
assert(ReadMA && "Couldn't find memory access for incoming load!");

const LoadInst *Load =
dyn_cast<const LoadInst>(CandidatePair.first->getAccessInstruction());
MemoryAccess::ReductionType RT =
getReductionType(dyn_cast<BinaryOperator>(Load->user_back()), Load);
POLLY_DEBUG(dbgs() << "'" << *ReadMA->getAccessInstruction()
<< "'\n\tflows into\n'"
<< *WriteMA->getAccessInstruction() << "'\n\t #"
<< MaInFlowSetElem.second.first << " times & RT: "
<< MaInFlowSetElem.second.second << "\n");

// If no overlapping access was found we mark the load and store as
// reduction like.
LoadMA->markAsReductionLike(RT);
StoreMA->markAsReductionLike(RT);
MemoryAccess::ReductionType RT = MaInFlowSetElem.second.second;
unsigned NumAllowableInFlow = 1;

// We allow the load to flow in exactly once for binary reductions
bool Valid = (MaInFlowSetElem.second.first == NumAllowableInFlow);

// Check if we saw a valid chain of binary operators.
Valid = Valid && RT != MemoryAccess::RT_BOTTOM;
Valid = Valid && RT != MemoryAccess::RT_NONE;

// Then check if the memory accesses allow a reduction.
Valid = Valid && checkCandidatePairAccesses(
ReadMA, WriteMA, Stmt.getDomain(), Stmt.MemAccs);

// Finally, mark the pair as a candidate or the load as an invalid one.
if (Valid)
ValidCandidates[std::make_pair(ReadMA, WriteMA)] = RT;
else
InvalidLoads.insert(ReadMA->getAccessInstruction());
}
}

// In the last step mark the memory accesses of candidate pairs as reduction
// like if the load wasn't marked invalid in the previous step.
for (auto &CandidatePair : ValidCandidates) {
MemoryAccess *LoadMA = CandidatePair.first.first;
if (InvalidLoads.count(LoadMA->getAccessInstruction()))
continue;
POLLY_DEBUG(
dbgs() << " Load :: "
<< *((CandidatePair.first.first)->getAccessInstruction())
<< "\n Store :: "
<< *((CandidatePair.first.second)->getAccessInstruction())
<< "\n are marked as reduction like\n");
MemoryAccess::ReductionType RT = CandidatePair.second;
CandidatePair.first.first->markAsReductionLike(RT);
CandidatePair.first.second->markAsReductionLike(RT);
}
}

Expand Down Expand Up @@ -2965,52 +3135,6 @@ void ScopBuilder::addInvariantLoads(ScopStmt &Stmt,
}
}

/// Gather the loads that may start a reduction chain ending in @p StoreMA.
///
/// A load is appended to @p Loads when the value stored by @p StoreMA is a
/// commutative and associative binary operator with exactly one use, located
/// in the same basic block as the store, and the load is an operand of that
/// operator with exactly one use that also lives in that basic block.
/// Nothing is appended if the access instruction of @p StoreMA is not a store
/// or if multiplicative reductions are disabled and the operator is a
/// Mul/FMul.
void ScopBuilder::collectCandidateReductionLoads(
    MemoryAccess *StoreMA, SmallVectorImpl<MemoryAccess *> &Loads) {
  // Only a real store instruction can terminate a candidate chain.
  auto *StoreInstr = dyn_cast<StoreInst>(StoreMA->getAccessInstruction());
  if (!StoreInstr)
    return;

  // The stored value has to be produced directly by a binary operator that
  // has no user other than this store.
  auto *Combiner = dyn_cast<BinaryOperator>(StoreInstr->getValueOperand());
  if (!Combiner || Combiner->getNumUses() != 1)
    return;

  // The operator has to be both commutative and associative.
  if (!(Combiner->isCommutative() && Combiner->isAssociative()))
    return;

  // The operator has to live in the same basic block as the store.
  if (Combiner->getParent() != StoreInstr->getParent())
    return;

  // Honor the option which turns off multiplicative reductions.
  const auto Opcode = Combiner->getOpcode();
  const bool IsMultiplicative =
      Opcode == Instruction::Mul || Opcode == Instruction::FMul;
  if (DisableMultiplicativeReductions && IsMultiplicative)
    return;

  // Inspect both operands in order. A load qualifies only if the operator is
  // its single use (so it cannot escape) and it shares the store's block.
  ScopStmt *OwningStmt = StoreMA->getStatement();
  for (unsigned OperandIdx = 0; OperandIdx < 2; ++OperandIdx) {
    auto *Candidate = dyn_cast<LoadInst>(Combiner->getOperand(OperandIdx));
    if (!Candidate)
      continue;
    if (Candidate->getNumUses() != 1)
      continue;
    if (Candidate->getParent() != StoreInstr->getParent())
      continue;
    Loads.push_back(&OwningStmt->getArrayAccessFor(Candidate));
  }
}

/// Find the canonical scop array info object for a set of invariant load
/// hoisted loads. The canonical array is the one that corresponds to the
/// first load in the list of accesses which is used as base pointer of a
Expand Down
12 changes: 10 additions & 2 deletions polly/lib/Analysis/ScopInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -533,6 +533,9 @@ MemoryAccess::getReductionOperatorStr(MemoryAccess::ReductionType RT) {
case MemoryAccess::RT_NONE:
llvm_unreachable("Requested a reduction operator string for a memory "
"access which isn't a reduction");
case MemoryAccess::RT_BOTTOM:
llvm_unreachable("Requested a reduction operator string for a internal "
"reduction type!");
case MemoryAccess::RT_ADD:
return "+";
case MemoryAccess::RT_MUL:
Expand Down Expand Up @@ -915,10 +918,15 @@ isl::id MemoryAccess::getId() const { return Id; }

raw_ostream &polly::operator<<(raw_ostream &OS,
MemoryAccess::ReductionType RT) {
if (RT == MemoryAccess::RT_NONE)
switch (RT) {
case MemoryAccess::RT_NONE:
case MemoryAccess::RT_BOTTOM:
OS << "NONE";
else
break;
default:
OS << MemoryAccess::getReductionOperatorStr(RT);
break;
}
return OS;
}

Expand Down
Loading