61
61
62
62
#include " llvm/CodeGen/ComplexDeinterleavingPass.h"
63
63
#include " llvm/ADT/MapVector.h"
64
+ #include " llvm/ADT/SetVector.h"
64
65
#include " llvm/ADT/Statistic.h"
65
66
#include " llvm/Analysis/TargetLibraryInfo.h"
66
67
#include " llvm/Analysis/TargetTransformInfo.h"
@@ -274,6 +275,13 @@ class ComplexDeinterleavingGraph {
274
275
// / `llvm.vector.reduce.fadd` when unroll factor isn't one.
275
276
MapVector<Instruction *, std::pair<PHINode *, Instruction *>> ReductionInfo;
276
277
278
+ // / In the case of reductions in unrolled loops, the %OutsideUser from
279
+ // / ReductionInfo is an add instruction that precedes the reduction.
280
+ // / UnrollInfo pairs values together if they are both operands of the same
281
+ // / add. This pairing info is then used to add the resulting complex
282
+ // / operations together before the final reduction.
283
+ MapVector<Value *, Value *> UnrollInfo;
284
+
277
285
// / In the process of detecting a reduction, we consider a pair of
278
286
// / %ReductionOP, which we refer to as real and imag (or vice versa), and
279
287
// / traverse the use-tree to detect complex operations. As this is a reduction
@@ -2253,8 +2261,31 @@ void ComplexDeinterleavingGraph::processReductionSingle(
2253
2261
auto *FinalReduction = ReductionInfo[Real].second ;
2254
2262
Builder.SetInsertPoint (&*FinalReduction->getParent ()->getFirstInsertionPt ());
2255
2263
2256
- auto *AddReduce = Builder.CreateAddReduce (OperationReplacement);
2264
+ Value *Other;
2265
+ bool EraseFinalReductionHere = false ;
2266
+ if (match (FinalReduction, m_c_Add (m_Specific (Real), m_Value (Other)))) {
2267
+ UnrollInfo[Real] = OperationReplacement;
2268
+ if (!UnrollInfo.contains (Other) || !FinalReduction->hasOneUser ())
2269
+ return ;
2270
+
2271
+ auto *User = *FinalReduction->user_begin ();
2272
+ if (!match (User, m_Intrinsic<Intrinsic::vector_reduce_add>()))
2273
+ return ;
2274
+
2275
+ FinalReduction = cast<Instruction>(User);
2276
+ Builder.SetInsertPoint (FinalReduction);
2277
+ OperationReplacement =
2278
+ Builder.CreateAdd (OperationReplacement, UnrollInfo[Other]);
2279
+
2280
+ UnrollInfo.erase (Real);
2281
+ UnrollInfo.erase (Other);
2282
+ EraseFinalReductionHere = true ;
2283
+ }
2284
+
2285
+ Value *AddReduce = Builder.CreateAddReduce (OperationReplacement);
2257
2286
FinalReduction->replaceAllUsesWith (AddReduce);
2287
+ if (EraseFinalReductionHere)
2288
+ FinalReduction->eraseFromParent ();
2258
2289
}
2259
2290
2260
2291
void ComplexDeinterleavingGraph::processReductionOperation (
@@ -2299,7 +2330,7 @@ void ComplexDeinterleavingGraph::processReductionOperation(
2299
2330
}
2300
2331
2301
2332
void ComplexDeinterleavingGraph::replaceNodes () {
2302
- SmallVector <Instruction *, 16 > DeadInstrRoots;
2333
+ SmallSetVector <Instruction *, 16 > DeadInstrRoots;
2303
2334
for (auto *RootInstruction : OrderedRoots) {
2304
2335
// Check if this potential root went through check process and we can
2305
2336
// deinterleave it
@@ -2316,20 +2347,25 @@ void ComplexDeinterleavingGraph::replaceNodes() {
2316
2347
auto *RootImag = cast<Instruction>(RootNode->Imag );
2317
2348
ReductionInfo[RootReal].first ->removeIncomingValue (BackEdge);
2318
2349
ReductionInfo[RootImag].first ->removeIncomingValue (BackEdge);
2319
- DeadInstrRoots.push_back (RootReal);
2320
- DeadInstrRoots.push_back (RootImag);
2350
+ DeadInstrRoots.insert (RootReal);
2351
+ DeadInstrRoots.insert (RootImag);
2321
2352
} else if (RootNode->Operation ==
2322
2353
ComplexDeinterleavingOperation::ReductionSingle) {
2323
2354
auto *RootInst = cast<Instruction>(RootNode->Real );
2324
2355
ReductionInfo[RootInst].first ->removeIncomingValue (BackEdge);
2325
- DeadInstrRoots.push_back (ReductionInfo[RootInst].second );
2356
+ DeadInstrRoots.insert (ReductionInfo[RootInst].second );
2326
2357
} else {
2327
2358
assert (R && " Unable to find replacement for RootInstruction" );
2328
- DeadInstrRoots.push_back (RootInstruction);
2359
+ DeadInstrRoots.insert (RootInstruction);
2329
2360
RootInstruction->replaceAllUsesWith (R);
2330
2361
}
2331
2362
}
2332
2363
2364
+ assert (UnrollInfo.empty () &&
2365
+ " UnrollInfo should be empty after replacing all nodes" );
2366
+
2367
+ for (auto *I : DeadInstrRoots)
2368
+ dbgs () << " Dead Instr Root: " << *I << " \n " ;
2333
2369
for (auto *I : DeadInstrRoots)
2334
2370
RecursivelyDeleteTriviallyDeadInstructions (I, TLI);
2335
2371
}
0 commit comments