Skip to content

Reduction series : Refactor reduction detection code #72343

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Dec 7, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 48 additions & 30 deletions polly/lib/Analysis/ScopBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2510,6 +2510,48 @@ static MemoryAccess::ReductionType getReductionType(const BinaryOperator *BinOp,
}
}

/// True if @p AllAccs intersects with @p MemAccs execpt @p LoadMA and @p
/// StoreMA
bool hasIntersectingAccesses(isl::set AllAccs, MemoryAccess *LoadMA,
MemoryAccess *StoreMA, isl::set Domain,
SmallVector<MemoryAccess *, 8> &MemAccs) {
bool HasIntersectingAccs = false;
for (MemoryAccess *MA : MemAccs) {
if (MA == LoadMA || MA == StoreMA)
continue;

isl::map AccRel = MA->getAccessRelation().intersect_domain(Domain);
isl::set Accs = AccRel.range();

if (AllAccs.has_equal_space(Accs)) {
isl::set OverlapAccs = Accs.intersect(AllAccs);
bool DoesIntersect = !OverlapAccs.is_empty();
HasIntersectingAccs |= DoesIntersect;
}
}
return HasIntersectingAccs;
}

/// Test if the accesses of @p LoadMA and @p StoreMA can form a reduction
bool checkCandidatePairAccesses(MemoryAccess *LoadMA, MemoryAccess *StoreMA,
isl::set Domain,
SmallVector<MemoryAccess *, 8> &MemAccs) {
isl::map LoadAccs = LoadMA->getAccessRelation();
isl::map StoreAccs = StoreMA->getAccessRelation();

// Skip those with obviously unequal base addresses.
bool Valid = LoadAccs.has_equal_space(StoreAccs);

// And check if the remaining for overlap with other memory accesses.
if (Valid) {
isl::map AllAccsRel = LoadAccs.unite(StoreAccs);
AllAccsRel = AllAccsRel.intersect_domain(Domain);
isl::set AllAccs = AllAccsRel.range();
Valid = !hasIntersectingAccesses(AllAccs, LoadMA, StoreMA, Domain, MemAccs);
}
return Valid;
}

void ScopBuilder::checkForReductions(ScopStmt &Stmt) {
SmallVector<MemoryAccess *, 2> Loads;
SmallVector<std::pair<MemoryAccess *, MemoryAccess *>, 4> Candidates;
Expand All @@ -2528,34 +2570,10 @@ void ScopBuilder::checkForReductions(ScopStmt &Stmt) {

// Then check each possible candidate pair.
for (const auto &CandidatePair : Candidates) {
bool Valid = true;
isl::map LoadAccs = CandidatePair.first->getAccessRelation();
isl::map StoreAccs = CandidatePair.second->getAccessRelation();

// Skip those with obviously unequal base addresses.
if (!LoadAccs.has_equal_space(StoreAccs)) {
continue;
}

// And check if the remaining for overlap with other memory accesses.
isl::map AllAccsRel = LoadAccs.unite(StoreAccs);
AllAccsRel = AllAccsRel.intersect_domain(Stmt.getDomain());
isl::set AllAccs = AllAccsRel.range();

for (MemoryAccess *MA : Stmt) {
if (MA == CandidatePair.first || MA == CandidatePair.second)
continue;

isl::map AccRel =
MA->getAccessRelation().intersect_domain(Stmt.getDomain());
isl::set Accs = AccRel.range();

if (AllAccs.has_equal_space(Accs)) {
isl::set OverlapAccs = Accs.intersect(AllAccs);
Valid = Valid && OverlapAccs.is_empty();
}
}

MemoryAccess *LoadMA = CandidatePair.first;
MemoryAccess *StoreMA = CandidatePair.second;
bool Valid = checkCandidatePairAccesses(LoadMA, StoreMA, Stmt.getDomain(),
Stmt.MemAccs);
if (!Valid)
continue;

Expand All @@ -2566,8 +2584,8 @@ void ScopBuilder::checkForReductions(ScopStmt &Stmt) {

// If no overlapping access was found we mark the load and store as
// reduction like.
CandidatePair.first->markAsReductionLike(RT);
CandidatePair.second->markAsReductionLike(RT);
LoadMA->markAsReductionLike(RT);
StoreMA->markAsReductionLike(RT);
}
}

Expand Down