Skip to content

Commit 12760e6

Browse files
committed
[llvm] Fix crash when complex deinterleaving operates on an unrolled loop
1 parent 3a975d6 commit 12760e6

File tree

1 file changed

+42
-6
lines changed

1 file changed

+42
-6
lines changed

llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp

Lines changed: 42 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@
6161

6262
#include "llvm/CodeGen/ComplexDeinterleavingPass.h"
6363
#include "llvm/ADT/MapVector.h"
64+
#include "llvm/ADT/SetVector.h"
6465
#include "llvm/ADT/Statistic.h"
6566
#include "llvm/Analysis/TargetLibraryInfo.h"
6667
#include "llvm/Analysis/TargetTransformInfo.h"
@@ -274,6 +275,13 @@ class ComplexDeinterleavingGraph {
274275
/// `llvm.vector.reduce.fadd` when unroll factor isn't one.
275276
MapVector<Instruction *, std::pair<PHINode *, Instruction *>> ReductionInfo;
276277

278+
/// In the case of reductions in unrolled loops, the %OutsideUser from
279+
/// ReductionInfo is an add instruction that precedes the reduction.
280+
/// UnrollInfo pairs values together if they are both operands of the same
281+
/// add. This pairing info is then used to add the resulting complex
282+
/// operations together before the final reduction.
283+
MapVector<Value *, Value *> UnrollInfo;
284+
277285
/// In the process of detecting a reduction, we consider a pair of
278286
/// %ReductionOP, which we refer to as real and imag (or vice versa), and
279287
/// traverse the use-tree to detect complex operations. As this is a reduction
@@ -2253,8 +2261,31 @@ void ComplexDeinterleavingGraph::processReductionSingle(
22532261
auto *FinalReduction = ReductionInfo[Real].second;
22542262
Builder.SetInsertPoint(&*FinalReduction->getParent()->getFirstInsertionPt());
22552263

2256-
auto *AddReduce = Builder.CreateAddReduce(OperationReplacement);
2264+
Value *Other;
2265+
bool EraseFinalReductionHere = false;
2266+
if (match(FinalReduction, m_c_Add(m_Specific(Real), m_Value(Other)))) {
2267+
UnrollInfo[Real] = OperationReplacement;
2268+
if (!UnrollInfo.contains(Other) || !FinalReduction->hasOneUser())
2269+
return;
2270+
2271+
auto *User = *FinalReduction->user_begin();
2272+
if (!match(User, m_Intrinsic<Intrinsic::vector_reduce_add>()))
2273+
return;
2274+
2275+
FinalReduction = cast<Instruction>(User);
2276+
Builder.SetInsertPoint(FinalReduction);
2277+
OperationReplacement =
2278+
Builder.CreateAdd(OperationReplacement, UnrollInfo[Other]);
2279+
2280+
UnrollInfo.erase(Real);
2281+
UnrollInfo.erase(Other);
2282+
EraseFinalReductionHere = true;
2283+
}
2284+
2285+
Value *AddReduce = Builder.CreateAddReduce(OperationReplacement);
22572286
FinalReduction->replaceAllUsesWith(AddReduce);
2287+
if (EraseFinalReductionHere)
2288+
FinalReduction->eraseFromParent();
22582289
}
22592290

22602291
void ComplexDeinterleavingGraph::processReductionOperation(
@@ -2299,7 +2330,7 @@ void ComplexDeinterleavingGraph::processReductionOperation(
22992330
}
23002331

23012332
void ComplexDeinterleavingGraph::replaceNodes() {
2302-
SmallVector<Instruction *, 16> DeadInstrRoots;
2333+
SmallSetVector<Instruction *, 16> DeadInstrRoots;
23032334
for (auto *RootInstruction : OrderedRoots) {
23042335
// Check if this potential root went through check process and we can
23052336
// deinterleave it
@@ -2316,20 +2347,25 @@ void ComplexDeinterleavingGraph::replaceNodes() {
23162347
auto *RootImag = cast<Instruction>(RootNode->Imag);
23172348
ReductionInfo[RootReal].first->removeIncomingValue(BackEdge);
23182349
ReductionInfo[RootImag].first->removeIncomingValue(BackEdge);
2319-
DeadInstrRoots.push_back(RootReal);
2320-
DeadInstrRoots.push_back(RootImag);
2350+
DeadInstrRoots.insert(RootReal);
2351+
DeadInstrRoots.insert(RootImag);
23212352
} else if (RootNode->Operation ==
23222353
ComplexDeinterleavingOperation::ReductionSingle) {
23232354
auto *RootInst = cast<Instruction>(RootNode->Real);
23242355
ReductionInfo[RootInst].first->removeIncomingValue(BackEdge);
2325-
DeadInstrRoots.push_back(ReductionInfo[RootInst].second);
2356+
DeadInstrRoots.insert(ReductionInfo[RootInst].second);
23262357
} else {
23272358
assert(R && "Unable to find replacement for RootInstruction");
2328-
DeadInstrRoots.push_back(RootInstruction);
2359+
DeadInstrRoots.insert(RootInstruction);
23292360
RootInstruction->replaceAllUsesWith(R);
23302361
}
23312362
}
23322363

2364+
assert(UnrollInfo.empty() &&
2365+
"UnrollInfo should be empty after replacing all nodes");
2366+
2367+
for (auto *I : DeadInstrRoots)
2368+
dbgs() << "Dead Instr Root: " << *I << "\n";
23332369
for (auto *I : DeadInstrRoots)
23342370
RecursivelyDeleteTriviallyDeadInstructions(I, TLI);
23352371
}

0 commit comments

Comments
 (0)