Skip to content

[Polly] Data flow reduction detection to cover more cases #84901

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Jul 30, 2024
2 changes: 1 addition & 1 deletion polly/include/polly/ScopBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -602,7 +602,7 @@ class ScopBuilder final {
/// results will escape during execution of the loop nest. We basically check
/// here that no other memory access can access the same memory as the
/// potential reduction.
void checkForReductions(ScopStmt &Stmt);
void checkForReductions(ScopStmt &Stmt, BasicBlock *Block);

/// Verify that all required invariant loads have been hoisted.
///
Expand Down
4 changes: 3 additions & 1 deletion polly/include/polly/ScopInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -470,6 +470,8 @@ class MemoryAccess final {
RT_BOR, ///< Bitwise Or
RT_BXOR, ///< Bitwise XOr
RT_BAND, ///< Bitwise And

RT_BOTTOM, ///< Pseudo type for the data flow analysis
};

using SubscriptsTy = SmallVector<const SCEV *, 4>;
Expand Down Expand Up @@ -1139,6 +1141,7 @@ class ScopStmt final {
friend class ScopBuilder;

public:
using MemoryAccessVec = llvm::SmallVector<MemoryAccess *, 8>;
/// Create the ScopStmt from a BasicBlock.
ScopStmt(Scop &parent, BasicBlock &bb, StringRef Name, Loop *SurroundingLoop,
std::vector<Instruction *> Instructions);
Expand Down Expand Up @@ -1206,7 +1209,6 @@ class ScopStmt final {
/// The memory accesses of this statement.
///
/// The only side effects of a statement are its memory accesses.
using MemoryAccessVec = llvm::SmallVector<MemoryAccess *, 8>;
MemoryAccessVec MemAccs;

/// Mapping from instructions to (scalar) memory accesses.
Expand Down
272 changes: 195 additions & 77 deletions polly/lib/Analysis/ScopBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2480,8 +2480,8 @@ void ScopBuilder::collectSurroundingLoops(ScopStmt &Stmt) {
}

/// Return the reduction type for a given binary operator.
static MemoryAccess::ReductionType getReductionType(const BinaryOperator *BinOp,
const Instruction *Load) {
static MemoryAccess::ReductionType
getReductionType(const BinaryOperator *BinOp) {
if (!BinOp)
return MemoryAccess::RT_NONE;
switch (BinOp->getOpcode()) {
Expand Down Expand Up @@ -2510,6 +2510,17 @@ static MemoryAccess::ReductionType getReductionType(const BinaryOperator *BinOp,
}
}

/// @brief Merge the reduction types @p RT0 and @p RT1 into a single one.
///
/// The result is @p RT1 when @p RT0 is the pseudo type RT_BOTTOM or when both
/// types are equal. Every other combination yields RT_NONE.
///
/// Note that only @p RT0 is tested against RT_BOTTOM; an RT_BOTTOM passed as
/// @p RT1 together with a different @p RT0 results in RT_NONE.
static MemoryAccess::ReductionType
combineReductionType(MemoryAccess::ReductionType RT0,
                     MemoryAccess::ReductionType RT1) {
  const bool Compatible = RT0 == MemoryAccess::RT_BOTTOM || RT0 == RT1;
  return Compatible ? RT1 : MemoryAccess::RT_NONE;
}

/// True if @p AllAccs intersects with @p MemAccs except @p LoadMA and @p
/// StoreMA
bool hasIntersectingAccesses(isl::set AllAccs, MemoryAccess *LoadMA,
Expand Down Expand Up @@ -2568,49 +2579,196 @@ bool checkCandidatePairAccesses(MemoryAccess *LoadMA, MemoryAccess *StoreMA,
// Finally, check if there are no other instructions accessing this memory
isl::map AllAccsRel = LoadAccs.unite(StoreAccs);
AllAccsRel = AllAccsRel.intersect_domain(Domain);

isl::set AllAccs = AllAccsRel.range();

Valid = !hasIntersectingAccesses(AllAccs, LoadMA, StoreMA, Domain, MemAccs);

LLVM_DEBUG(dbgs() << " == The accessed memory is " << (Valid ? "not " : "")
<< "accessed by other instructions!\n");
}

return Valid;
}

void ScopBuilder::checkForReductions(ScopStmt &Stmt) {
SmallVector<MemoryAccess *, 2> Loads;
SmallVector<std::pair<MemoryAccess *, MemoryAccess *>, 4> Candidates;
/// Perform a data flow analysis on the current basic block to propagate the
/// uses of loaded values. Then check and mark the memory accesses which are
/// part of reduction like chains.
///
/// NOTE: This assumes independent blocks and breaks otherwise.
void ScopBuilder::checkForReductions(ScopStmt &Stmt, BasicBlock *Block) {
// During the data flow analysis we use the State variable to keep track of
// the used "load-instructions" for each instruction in the basic block.
// This includes the LLVM-IR of the load and the "number of uses" (or the
// number of paths in the operand tree which end in this load).
using StatePairTy = std::pair<unsigned, MemoryAccess::ReductionType>;
using FlowInSetTy = MapVector<const LoadInst *, StatePairTy>;
using StateTy = MapVector<const Instruction *, FlowInSetTy>;
StateTy State;

// Invalid loads are loads which have uses we can't track properly in the
// state map. This includes loads which:
// o do not form a reduction when they flow into a memory location:
// (e.g., A[i] = B[i] * 3 and A[i] = A[i] * A[i] + A[i])
// o are used by a non binary operator or one which is not commutative
// and associative (e.g., A[i] = A[i] % 3)
// o might change the control flow (e.g., if (A[i]))
// o are used in indirect memory accesses (e.g., A[B[i]])
// o are used outside the current basic block
SmallPtrSet<const Instruction *, 8> InvalidLoads;

// Run the data flow analysis for all values in the basic block
for (Instruction &Inst : *Block) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since a BasicBlock consists of multiple ScopStmts, this really should enumerate the instructions of the statement, not of the block.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tried to use Stmt.getInstructions() here, but for scop statements with multiple blocks, the getInstructions() API returns the instructions of the first block only. I am not sure whether this is a bug in the function or the expected behavior.

Copy link
Member

@Meinersbur Meinersbur May 13, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is expected behavior. getInstructions() cannot capture the entire control flow in a RegionStmt so it only represents the instructions of the header block. See

// Block statements and the entry blocks of region statement are code
// generated from instruction lists. This allow us to optimize the
// instructions that belong to a certain scop statement. As the code
// structure of region statements might be arbitrary complex, optimizing the
// instruction list is not yet supported.
if (Stmt.isBlockStmt() || (Stmt.isRegionStmt() && Stmt.getEntryBlock() == BB))
for (Instruction *Inst : Stmt.getInstructions())
copyInstruction(Stmt, Inst, BBMap, LTS, NewAccesses);
else
for (Instruction &Inst : *BB)
copyInstruction(Stmt, &Inst, BBMap, LTS, NewAccesses);
for rationale and how to handle.

Standard behavior is to not try to analyze the other BasicBlocks because of the control flow (conditional- and loop-execution). The simple cases (single, unconditional execution) are in the header.

bool UsedOutsideBlock = any_of(Inst.users(), [Block](User *U) {
return cast<Instruction>(U)->getParent() != Block;
});

// Treat loads and stores special
if (auto *Load = dyn_cast<LoadInst>(&Inst)) {
// Invalidate all loads used which feed into the address of this load.
if (auto *Ptr = dyn_cast<Instruction>(Load->getPointerOperand())) {
const auto &It = State.find(Ptr);
if (It != State.end())
for (const auto &FlowInSetElem : It->second)
InvalidLoads.insert(FlowInSetElem.first);
}

// If this load is used outside this block, invalidate it.
if (UsedOutsideBlock)
InvalidLoads.insert(Load);

// First collect candidate load-store reduction chains by iterating over all
// stores and collecting possible reduction loads.
for (MemoryAccess *StoreMA : Stmt) {
if (StoreMA->isRead())
// And indicate that this load uses itself once but without specifying
// any reduction operator.
State[Load].insert(
std::make_pair(Load, std::make_pair(1, MemoryAccess::RT_BOTTOM)));
continue;
}

Loads.clear();
collectCandidateReductionLoads(StoreMA, Loads);
for (MemoryAccess *LoadMA : Loads)
Candidates.push_back(std::make_pair(LoadMA, StoreMA));
}
if (auto *Store = dyn_cast<StoreInst>(&Inst)) {
// Invalidate all loads which feed into the address of this store.
if (const Instruction *Ptr =
dyn_cast<Instruction>(Store->getPointerOperand())) {
const auto &It = State.find(Ptr);
if (It != State.end())
for (const auto &FlowInSetElem : It->second)
InvalidLoads.insert(FlowInSetElem.first);
}

// Then check each possible candidate pair.
for (const auto &CandidatePair : Candidates) {
MemoryAccess *LoadMA = CandidatePair.first;
MemoryAccess *StoreMA = CandidatePair.second;
bool Valid = checkCandidatePairAccesses(LoadMA, StoreMA, Stmt.getDomain(),
Stmt.MemAccs);
if (!Valid)
// Propagate the uses of the value operand to the store
if (auto *ValueInst = dyn_cast<Instruction>(Store->getValueOperand()))
State.insert(std::make_pair(Store, State[ValueInst]));
continue;
}

// Non load and store instructions are either binary operators or they will
// invalidate all used loads.
auto *BinOp = dyn_cast<BinaryOperator>(&Inst);
auto CurRedType = getReductionType(BinOp);
LLVM_DEBUG(dbgs() << "CurInst: " << Inst << " RT: " << CurRedType << "\n");

// Iterate over all operands and propagate their input loads to instruction.
FlowInSetTy &InstInFlowSet = State[&Inst];
for (Use &Op : Inst.operands()) {
auto *OpInst = dyn_cast<Instruction>(Op);
if (!OpInst)
continue;

LLVM_DEBUG(dbgs().indent(4) << "Op Inst: " << *OpInst << "\n");
const StateTy::iterator &OpInFlowSetIt = State.find(OpInst);
if (OpInFlowSetIt == State.end())
continue;

// Iterate over all the input loads of the operand and combine them
// with the input loads of current instruction.
FlowInSetTy &OpInFlowSet = OpInFlowSetIt->second;
for (auto &OpInFlowPair : OpInFlowSet) {
unsigned OpFlowIn = OpInFlowPair.second.first;
unsigned InstFlowIn = InstInFlowSet[OpInFlowPair.first].first;

auto OpRedType = OpInFlowPair.second.second;
auto InstRedType = InstInFlowSet[OpInFlowPair.first].second;

auto NewRedType = combineReductionType(OpRedType, CurRedType);
if (InstFlowIn)
NewRedType = combineReductionType(NewRedType, InstRedType);

LLVM_DEBUG(dbgs().indent(8) << "OpRedType: " << OpRedType << "\n");
LLVM_DEBUG(dbgs().indent(8) << "NewRedType: " << NewRedType << "\n");
InstInFlowSet[OpInFlowPair.first] =
std::make_pair(OpFlowIn + InstFlowIn, NewRedType);
}
}

// If this operation is used outside the block, invalidate all the loads
// which feed into it.
if (UsedOutsideBlock)
for (const auto &FlowInSetElem : InstInFlowSet)
InvalidLoads.insert(FlowInSetElem.first);
}

// All used loads are propagated through the whole basic block; now try to
// find valid reduction like candidate pairs. These load-store pairs fulfill
// all reduction like properties with regards to only this load-store chain.
// We later have to check if the loaded value was invalidated by an
// instruction not in that chain.
using MemAccPair = std::pair<MemoryAccess *, MemoryAccess *>;
DenseMap<MemAccPair, MemoryAccess::ReductionType> ValidCandidates;
DominatorTree *DT = Stmt.getParent()->getDT();

// Iterate over all write memory accesses and check the loads flowing into
// it for reduction candidate pairs.
for (MemoryAccess *WriteMA : Stmt.MemAccs) {
if (WriteMA->isRead())
continue;
StoreInst *St = dyn_cast<StoreInst>(WriteMA->getAccessInstruction());
if (!St || St->isVolatile())
continue;

FlowInSetTy &MaInFlowSet = State[WriteMA->getAccessInstruction()];
bool Valid = false;

for (auto &MaInFlowSetElem : MaInFlowSet) {
MemoryAccess *ReadMA = &Stmt.getArrayAccessFor(MaInFlowSetElem.first);
assert(ReadMA && "Couldn't find memory access for incoming load!");

const LoadInst *Load =
dyn_cast<const LoadInst>(CandidatePair.first->getAccessInstruction());
MemoryAccess::ReductionType RT =
getReductionType(dyn_cast<BinaryOperator>(Load->user_back()), Load);
LLVM_DEBUG(dbgs() << "'" << *ReadMA->getAccessInstruction()
<< "'\n\tflows into\n'"
<< *WriteMA->getAccessInstruction() << "'\n\t #"
<< MaInFlowSetElem.second.first << " times & RT: "
<< MaInFlowSetElem.second.second << "\n");

// If no overlapping access was found we mark the load and store as
// reduction like.
LoadMA->markAsReductionLike(RT);
StoreMA->markAsReductionLike(RT);
MemoryAccess::ReductionType RT = MaInFlowSetElem.second.second;
unsigned NumAllowableInFlow = 1;

// We allow the load to flow in exactly once for binary reductions
Valid = (MaInFlowSetElem.second.first == NumAllowableInFlow);

// Check if we saw a valid chain of binary operators.
Valid = Valid && RT != MemoryAccess::RT_BOTTOM;
Valid = Valid && RT != MemoryAccess::RT_NONE;

// Then check if the memory accesses allow a reduction.
Valid = Valid && checkCandidatePairAccesses(
ReadMA, WriteMA, Stmt.getDomain(), Stmt.MemAccs);

// Finally, mark the pair as a candidate or the load as an invalid one.
if (Valid)
ValidCandidates[std::make_pair(ReadMA, WriteMA)] = RT;
else
InvalidLoads.insert(ReadMA->getAccessInstruction());
}
}

// In the last step mark the memory accesses of candidate pairs as reduction
// like if the load wasn't marked invalid in the previous step.
for (auto &CandidatePair : ValidCandidates) {
MemoryAccess *LoadMA = CandidatePair.first.first;
if (InvalidLoads.count(LoadMA->getAccessInstruction()))
continue;

MemoryAccess::ReductionType RT = CandidatePair.second;
CandidatePair.first.first->markAsReductionLike(RT);
CandidatePair.first.second->markAsReductionLike(RT);
}
}

Expand Down Expand Up @@ -2963,52 +3121,6 @@ void ScopBuilder::addInvariantLoads(ScopStmt &Stmt,
}
}

/// Collect the loads that may form a reduction together with @p StoreMA.
///
/// A candidate exists only when the access instruction of @p StoreMA is a
/// store whose value operand is a single-use, commutative and associative
/// binary operator located in the same parent block as the store. Each
/// operand of that operator that is a single-use load from the same parent
/// block has its array access appended to @p Loads (operand 0 first).
///
/// @param StoreMA The write access to start the search from.
/// @param Loads   Output vector receiving the candidate read accesses;
///                existing elements are left untouched.
void ScopBuilder::collectCandidateReductionLoads(
    MemoryAccess *StoreMA, SmallVectorImpl<MemoryAccess *> &Loads) {
  ScopStmt *OwningStmt = StoreMA->getStatement();

  auto *StoreInstr = dyn_cast<StoreInst>(StoreMA->getAccessInstruction());
  if (!StoreInstr)
    return;

  // The stored value must come straight out of a binary operator that has
  // no other use than this store.
  auto *Op = dyn_cast<BinaryOperator>(StoreInstr->getValueOperand());
  if (!Op || Op->getNumUses() != 1)
    return;

  // Only commutative and associative operators qualify.
  if (!(Op->isCommutative() && Op->isAssociative()))
    return;

  // The operator has to share the parent block of the store.
  if (Op->getParent() != StoreInstr->getParent())
    return;

  // Multiplicative reductions can be switched off.
  const auto Opcode = Op->getOpcode();
  const bool IsMultiplicative =
      Opcode == Instruction::Mul || Opcode == Instruction::FMul;
  if (DisableMultiplicativeReductions && IsMultiplicative)
    return;

  // Inspect both operands in order. A load is a candidate only if the
  // operator is its sole use (so the loaded value cannot escape) and it sits
  // in the same parent block as the store.
  for (unsigned Idx = 0; Idx < 2; ++Idx) {
    auto *Candidate = dyn_cast<LoadInst>(Op->getOperand(Idx));
    if (!Candidate || Candidate->getNumUses() != 1)
      continue;
    if (Candidate->getParent() != StoreInstr->getParent())
      continue;
    Loads.push_back(&OwningStmt->getArrayAccessFor(Candidate));
  }
}

/// Find the canonical scop array info object for a set of invariant load
/// hoisted loads. The canonical array is the one that corresponds to the
/// first load in the list of accesses which is used as base pointer of a
Expand Down Expand Up @@ -3593,8 +3705,14 @@ void ScopBuilder::buildScop(Region &R, AssumptionCache &AC) {
buildDomain(Stmt);
buildAccessRelations(Stmt);

if (DetectReductions)
checkForReductions(Stmt);
if (DetectReductions) {
BasicBlock *BB = Stmt.getBasicBlock();
if (BB)
checkForReductions(Stmt, BB);
else
for (BasicBlock *Block : Stmt.getRegion()->blocks())
checkForReductions(Stmt, Block);
}
}

// Check early for a feasible runtime context.
Expand Down
12 changes: 10 additions & 2 deletions polly/lib/Analysis/ScopInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -532,6 +532,9 @@ MemoryAccess::getReductionOperatorStr(MemoryAccess::ReductionType RT) {
case MemoryAccess::RT_NONE:
llvm_unreachable("Requested a reduction operator string for a memory "
"access which isn't a reduction");
case MemoryAccess::RT_BOTTOM:
llvm_unreachable("Requested a reduction operator string for a internal "
"reduction type!");
case MemoryAccess::RT_ADD:
return "+";
case MemoryAccess::RT_MUL:
Expand Down Expand Up @@ -914,10 +917,15 @@ isl::id MemoryAccess::getId() const { return Id; }

raw_ostream &polly::operator<<(raw_ostream &OS,
MemoryAccess::ReductionType RT) {
if (RT == MemoryAccess::RT_NONE)
switch (RT) {
case MemoryAccess::RT_NONE:
case MemoryAccess::RT_BOTTOM:
OS << "NONE";
else
break;
default:
OS << MemoryAccess::getReductionOperatorStr(RT);
break;
}
return OS;
}

Expand Down
Loading