@@ -2481,8 +2481,8 @@ void ScopBuilder::collectSurroundingLoops(ScopStmt &Stmt) {
2481
2481
}
2482
2482
2483
2483
// / Return the reduction type for a given binary operator.
2484
- static MemoryAccess::ReductionType getReductionType ( const BinaryOperator *BinOp,
2485
- const Instruction *Load ) {
2484
+ static MemoryAccess::ReductionType
2485
+ getReductionType ( const BinaryOperator *BinOp ) {
2486
2486
if (!BinOp)
2487
2487
return MemoryAccess::RT_NONE;
2488
2488
switch (BinOp->getOpcode ()) {
@@ -2511,6 +2511,17 @@ static MemoryAccess::ReductionType getReductionType(const BinaryOperator *BinOp,
2511
2511
}
2512
2512
}
2513
2513
2514
+ // / @brief Combine two reduction types: RT_BOTTOM in @p RT0 yields @p RT1, equal types are kept, and any other pair yields RT_NONE.
2515
+ static MemoryAccess::ReductionType
2516
+ combineReductionType (MemoryAccess::ReductionType RT0,
2517
+ MemoryAccess::ReductionType RT1) {
2518
+ if (RT0 == MemoryAccess::RT_BOTTOM)
2519
+ return RT1;
2520
+ if (RT0 == RT1)
2521
+ return RT1;
2522
+ return MemoryAccess::RT_NONE;
2523
+ }
2524
+
2514
2525
// / True if @p AllAccs intersects with @p MemAccs except @p LoadMA and @p
2515
2526
// / StoreMA
2516
2527
bool hasIntersectingAccesses (isl::set AllAccs, MemoryAccess *LoadMA,
@@ -2571,47 +2582,206 @@ bool checkCandidatePairAccesses(MemoryAccess *LoadMA, MemoryAccess *StoreMA,
2571
2582
AllAccsRel = AllAccsRel.intersect_domain (Domain);
2572
2583
isl::set AllAccs = AllAccsRel.range ();
2573
2584
Valid = !hasIntersectingAccesses (AllAccs, LoadMA, StoreMA, Domain, MemAccs);
2574
-
2575
2585
POLLY_DEBUG (dbgs () << " == The accessed memory is " << (Valid ? " not " : " " )
2576
2586
<< " accessed by other instructions!\n " );
2577
2587
}
2588
+
2578
2589
return Valid;
2579
2590
}
2580
2591
2581
2592
void ScopBuilder::checkForReductions (ScopStmt &Stmt) {
2582
- SmallVector<MemoryAccess *, 2 > Loads;
2583
- SmallVector<std::pair<MemoryAccess *, MemoryAccess *>, 4 > Candidates;
2593
+ // Perform a data flow analysis on the current scop statement to propagate the
2594
+ // uses of loaded values. Then check and mark the memory accesses which are
2595
+ // part of reduction like chains.
2596
+ // During the data flow analysis we use the State variable to keep track of
2597
+ // the used "load-instructions" for each instruction in the scop statement.
2598
+ // This includes the LLVM-IR of the load and the "number of uses" (or the
2599
+ // number of paths in the operand tree which end in this load).
2600
+ using StatePairTy = std::pair<unsigned , MemoryAccess::ReductionType>;
2601
+ using FlowInSetTy = MapVector<const LoadInst *, StatePairTy>;
2602
+ using StateTy = MapVector<const Instruction *, FlowInSetTy>;
2603
+ StateTy State;
2604
+
2605
+ // Invalid loads are loads which have uses we can't track properly in the
2606
+ // state map. This includes loads which:
2607
+ // o do not form a reduction when they flow into a memory location:
2608
+ // (e.g., A[i] = B[i] * 3 and A[i] = A[i] * A[i] + A[i])
2609
+ // o are used by a non binary operator or one which is not commutative
2610
+ // and associative (e.g., A[i] = A[i] % 3)
2611
+ // o might change the control flow (e.g., if (A[i]))
2612
+ // o are used in indirect memory accesses (e.g., A[B[i]])
2613
+ // o are used outside the current scop statement
2614
+ SmallPtrSet<const Instruction *, 8 > InvalidLoads;
2615
+ SmallVector<BasicBlock *, 8 > ScopBlocks;
2616
+ BasicBlock *BB = Stmt.getBasicBlock ();
2617
+ if (BB)
2618
+ ScopBlocks.push_back (BB);
2619
+ else
2620
+ for (BasicBlock *Block : Stmt.getRegion ()->blocks ())
2621
+ ScopBlocks.push_back (Block);
2622
+ // Run the data flow analysis for all values in the scop statement
2623
+ for (BasicBlock *Block : ScopBlocks) {
2624
+ for (Instruction &Inst : *Block) {
2625
+ if ((Stmt.getParent ())->getStmtFor (&Inst) != &Stmt)
2626
+ continue ;
2627
+ bool UsedOutsideStmt = any_of (Inst.users (), [&Stmt](User *U) {
2628
+ return (Stmt.getParent ())->getStmtFor (cast<Instruction>(U)) != &Stmt;
2629
+ });
2630
+ // Treat loads and stores specially
2631
+ if (auto *Load = dyn_cast<LoadInst>(&Inst)) {
2632
+ // Invalidate all loads which feed into the address of this load.
2633
+ if (auto *Ptr = dyn_cast<Instruction>(Load->getPointerOperand ())) {
2634
+ const auto &It = State.find (Ptr);
2635
+ if (It != State.end ())
2636
+ for (const auto &FlowInSetElem : It->second )
2637
+ InvalidLoads.insert (FlowInSetElem.first );
2638
+ }
2584
2639
2585
- // First collect candidate load-store reduction chains by iterating over all
2586
- // stores and collecting possible reduction loads.
2587
- for (MemoryAccess *StoreMA : Stmt) {
2588
- if (StoreMA->isRead ())
2589
- continue ;
2640
+ // If this load is used outside this stmt, invalidate it.
2641
+ if (UsedOutsideStmt)
2642
+ InvalidLoads.insert (Load);
2643
+
2644
+ // And indicate that this load uses itself once but without specifying
2645
+ // any reduction operator.
2646
+ State[Load].insert (
2647
+ std::make_pair (Load, std::make_pair (1 , MemoryAccess::RT_BOTTOM)));
2648
+ continue ;
2649
+ }
2650
+
2651
+ if (auto *Store = dyn_cast<StoreInst>(&Inst)) {
2652
+ // Invalidate all loads which feed into the address of this store.
2653
+ if (const Instruction *Ptr =
2654
+ dyn_cast<Instruction>(Store->getPointerOperand ())) {
2655
+ const auto &It = State.find (Ptr);
2656
+ if (It != State.end ())
2657
+ for (const auto &FlowInSetElem : It->second )
2658
+ InvalidLoads.insert (FlowInSetElem.first );
2659
+ }
2660
+
2661
+ // Propagate the uses of the value operand to the store
2662
+ if (auto *ValueInst = dyn_cast<Instruction>(Store->getValueOperand ()))
2663
+ State.insert (std::make_pair (Store, State[ValueInst]));
2664
+ continue ;
2665
+ }
2666
+
2667
+ // Non load and store instructions are either binary operators or they
2668
+ // will invalidate all used loads.
2669
+ auto *BinOp = dyn_cast<BinaryOperator>(&Inst);
2670
+ MemoryAccess::ReductionType CurRedType = getReductionType (BinOp);
2671
+ POLLY_DEBUG (dbgs () << " CurInst: " << Inst << " RT: " << CurRedType
2672
+ << " \n " );
2673
+
2674
+ // Iterate over all operands and propagate their input loads to
2675
+ // this instruction.
2676
+ FlowInSetTy &InstInFlowSet = State[&Inst];
2677
+ for (Use &Op : Inst.operands ()) {
2678
+ auto *OpInst = dyn_cast<Instruction>(Op);
2679
+ if (!OpInst)
2680
+ continue ;
2681
+
2682
+ POLLY_DEBUG (dbgs ().indent (4 ) << " Op Inst: " << *OpInst << " \n " );
2683
+ const StateTy::iterator &OpInFlowSetIt = State.find (OpInst);
2684
+ if (OpInFlowSetIt == State.end ())
2685
+ continue ;
2686
+
2687
+ // Iterate over all the input loads of the operand and combine them
2688
+ // with the input loads of current instruction.
2689
+ FlowInSetTy &OpInFlowSet = OpInFlowSetIt->second ;
2690
+ for (auto &OpInFlowPair : OpInFlowSet) {
2691
+ unsigned OpFlowIn = OpInFlowPair.second .first ;
2692
+ unsigned InstFlowIn = InstInFlowSet[OpInFlowPair.first ].first ;
2693
+
2694
+ MemoryAccess::ReductionType OpRedType = OpInFlowPair.second .second ;
2695
+ MemoryAccess::ReductionType InstRedType =
2696
+ InstInFlowSet[OpInFlowPair.first ].second ;
2697
+
2698
+ MemoryAccess::ReductionType NewRedType =
2699
+ combineReductionType (OpRedType, CurRedType);
2700
+ if (InstFlowIn)
2701
+ NewRedType = combineReductionType (NewRedType, InstRedType);
2702
+
2703
+ POLLY_DEBUG (dbgs ().indent (8 ) << " OpRedType: " << OpRedType << " \n " );
2704
+ POLLY_DEBUG (dbgs ().indent (8 ) << " NewRedType: " << NewRedType << " \n " );
2705
+ InstInFlowSet[OpInFlowPair.first ] =
2706
+ std::make_pair (OpFlowIn + InstFlowIn, NewRedType);
2707
+ }
2708
+ }
2590
2709
2591
- Loads.clear ();
2592
- collectCandidateReductionLoads (StoreMA, Loads);
2593
- for (MemoryAccess *LoadMA : Loads)
2594
- Candidates.push_back (std::make_pair (LoadMA, StoreMA));
2710
+ // If this operation is used outside the stmt, invalidate all the loads
2711
+ // which feed into it.
2712
+ if (UsedOutsideStmt)
2713
+ for (const auto &FlowInSetElem : InstInFlowSet)
2714
+ InvalidLoads.insert (FlowInSetElem.first );
2715
+ }
2595
2716
}
2596
2717
2597
- // Then check each possible candidate pair.
2598
- for (const auto &CandidatePair : Candidates) {
2599
- MemoryAccess *LoadMA = CandidatePair.first ;
2600
- MemoryAccess *StoreMA = CandidatePair.second ;
2601
- bool Valid = checkCandidatePairAccesses (LoadMA, StoreMA, Stmt.getDomain (),
2602
- Stmt.MemAccs );
2603
- if (!Valid)
2718
+ // All used loads are propagated through the whole basic block; now try to
2719
+ // find valid reduction-like candidate pairs. These load-store pairs fulfill
2720
+ // all reduction like properties with regards to only this load-store chain.
2721
+ // We later have to check if the loaded value was invalidated by an
2722
+ // instruction not in that chain.
2723
+ using MemAccPair = std::pair<MemoryAccess *, MemoryAccess *>;
2724
+ DenseMap<MemAccPair, MemoryAccess::ReductionType> ValidCandidates;
2725
+ DominatorTree *DT = Stmt.getParent ()->getDT ();
2726
+
2727
+ // Iterate over all write memory accesses and check the loads flowing into
2728
+ // it for reduction candidate pairs.
2729
+ for (MemoryAccess *WriteMA : Stmt.MemAccs ) {
2730
+ if (WriteMA->isRead ())
2731
+ continue ;
2732
+ StoreInst *St = dyn_cast<StoreInst>(WriteMA->getAccessInstruction ());
2733
+ if (!St)
2604
2734
continue ;
2735
+ assert (!St->isVolatile ());
2736
+
2737
+ FlowInSetTy &MaInFlowSet = State[WriteMA->getAccessInstruction ()];
2738
+ for (auto &MaInFlowSetElem : MaInFlowSet) {
2739
+ MemoryAccess *ReadMA = &Stmt.getArrayAccessFor (MaInFlowSetElem.first );
2740
+ assert (ReadMA && " Couldn't find memory access for incoming load!" );
2605
2741
2606
- const LoadInst *Load =
2607
- dyn_cast<const LoadInst>(CandidatePair.first ->getAccessInstruction ());
2608
- MemoryAccess::ReductionType RT =
2609
- getReductionType (dyn_cast<BinaryOperator>(Load->user_back ()), Load);
2742
+ POLLY_DEBUG (dbgs () << " '" << *ReadMA->getAccessInstruction ()
2743
+ << " '\n\t flows into\n '"
2744
+ << *WriteMA->getAccessInstruction () << " '\n\t #"
2745
+ << MaInFlowSetElem.second .first << " times & RT: "
2746
+ << MaInFlowSetElem.second .second << " \n " );
2610
2747
2611
- // If no overlapping access was found we mark the load and store as
2612
- // reduction like.
2613
- LoadMA->markAsReductionLike (RT);
2614
- StoreMA->markAsReductionLike (RT);
2748
+ MemoryAccess::ReductionType RT = MaInFlowSetElem.second .second ;
2749
+ unsigned NumAllowableInFlow = 1 ;
2750
+
2751
+ // We allow the load to flow in exactly once for binary reductions
2752
+ bool Valid = (MaInFlowSetElem.second .first == NumAllowableInFlow);
2753
+
2754
+ // Check if we saw a valid chain of binary operators.
2755
+ Valid = Valid && RT != MemoryAccess::RT_BOTTOM;
2756
+ Valid = Valid && RT != MemoryAccess::RT_NONE;
2757
+
2758
+ // Then check if the memory accesses allow a reduction.
2759
+ Valid = Valid && checkCandidatePairAccesses (
2760
+ ReadMA, WriteMA, Stmt.getDomain (), Stmt.MemAccs );
2761
+
2762
+ // Finally, mark the pair as a candidate or the load as an invalid one.
2763
+ if (Valid)
2764
+ ValidCandidates[std::make_pair (ReadMA, WriteMA)] = RT;
2765
+ else
2766
+ InvalidLoads.insert (ReadMA->getAccessInstruction ());
2767
+ }
2768
+ }
2769
+
2770
+ // In the last step mark the memory accesses of candidate pairs as reduction
2771
+ // like if the load wasn't marked invalid in the previous step.
2772
+ for (auto &CandidatePair : ValidCandidates) {
2773
+ MemoryAccess *LoadMA = CandidatePair.first .first ;
2774
+ if (InvalidLoads.count (LoadMA->getAccessInstruction ()))
2775
+ continue ;
2776
+ POLLY_DEBUG (
2777
+ dbgs () << " Load :: "
2778
+ << *((CandidatePair.first .first )->getAccessInstruction ())
2779
+ << " \n Store :: "
2780
+ << *((CandidatePair.first .second )->getAccessInstruction ())
2781
+ << " \n are marked as reduction like\n " );
2782
+ MemoryAccess::ReductionType RT = CandidatePair.second ;
2783
+ CandidatePair.first .first ->markAsReductionLike (RT);
2784
+ CandidatePair.first .second ->markAsReductionLike (RT);
2615
2785
}
2616
2786
}
2617
2787
@@ -2965,52 +3135,6 @@ void ScopBuilder::addInvariantLoads(ScopStmt &Stmt,
2965
3135
}
2966
3136
}
2967
3137
2968
- void ScopBuilder::collectCandidateReductionLoads (
2969
- MemoryAccess *StoreMA, SmallVectorImpl<MemoryAccess *> &Loads) {
2970
- ScopStmt *Stmt = StoreMA->getStatement ();
2971
-
2972
- auto *Store = dyn_cast<StoreInst>(StoreMA->getAccessInstruction ());
2973
- if (!Store)
2974
- return ;
2975
-
2976
- // Skip if there is not one binary operator between the load and the store
2977
- auto *BinOp = dyn_cast<BinaryOperator>(Store->getValueOperand ());
2978
- if (!BinOp)
2979
- return ;
2980
-
2981
- // Skip if the binary operators has multiple uses
2982
- if (BinOp->getNumUses () != 1 )
2983
- return ;
2984
-
2985
- // Skip if the opcode of the binary operator is not commutative/associative
2986
- if (!BinOp->isCommutative () || !BinOp->isAssociative ())
2987
- return ;
2988
-
2989
- // Skip if the binary operator is outside the current SCoP
2990
- if (BinOp->getParent () != Store->getParent ())
2991
- return ;
2992
-
2993
- // Skip if it is a multiplicative reduction and we disabled them
2994
- if (DisableMultiplicativeReductions &&
2995
- (BinOp->getOpcode () == Instruction::Mul ||
2996
- BinOp->getOpcode () == Instruction::FMul))
2997
- return ;
2998
-
2999
- // Check the binary operator operands for a candidate load
3000
- auto *PossibleLoad0 = dyn_cast<LoadInst>(BinOp->getOperand (0 ));
3001
- auto *PossibleLoad1 = dyn_cast<LoadInst>(BinOp->getOperand (1 ));
3002
- if (!PossibleLoad0 && !PossibleLoad1)
3003
- return ;
3004
-
3005
- // A load is only a candidate if it cannot escape (thus has only this use)
3006
- if (PossibleLoad0 && PossibleLoad0->getNumUses () == 1 )
3007
- if (PossibleLoad0->getParent () == Store->getParent ())
3008
- Loads.push_back (&Stmt->getArrayAccessFor (PossibleLoad0));
3009
- if (PossibleLoad1 && PossibleLoad1->getNumUses () == 1 )
3010
- if (PossibleLoad1->getParent () == Store->getParent ())
3011
- Loads.push_back (&Stmt->getArrayAccessFor (PossibleLoad1));
3012
- }
3013
-
3014
3138
// / Find the canonical scop array info object for a set of invariant load
3015
3139
// / hoisted loads. The canonical array is the one that corresponds to the
3016
3140
// / first load in the list of accesses which is used as base pointer of a
0 commit comments