@@ -2674,11 +2674,255 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::applyStaticChunkedWorkshareLoop(
   return {DispatchAfter, DispatchAfter->getFirstInsertionPt()};
 }
 
+// Returns an LLVM function to call for executing an OpenMP static worksharing
+// for loop depending on `Ty`. Only i32 and i64 are supported by the runtime.
+// Always interpret integers as unsigned similarly to CanonicalLoopInfo.
+static FunctionCallee
+getKmpcForStaticLoopForType(Type *Ty, OpenMPIRBuilder *OMPBuilder,
+                            OpenMPIRBuilder::WorksharingLoopType LoopType) {
+  unsigned Bitwidth = Ty->getIntegerBitWidth();
+  Module &M = OMPBuilder->M;
+  switch (LoopType) {
+  case OpenMPIRBuilder::WorksharingLoopType::ForStaticLoop:
+    if (Bitwidth == 32)
+      return OMPBuilder->getOrCreateRuntimeFunction(
+          M, omp::RuntimeFunction::OMPRTL___kmpc_for_static_loop_4u);
+    if (Bitwidth == 64)
+      return OMPBuilder->getOrCreateRuntimeFunction(
+          M, omp::RuntimeFunction::OMPRTL___kmpc_for_static_loop_8u);
+    break;
+  case OpenMPIRBuilder::WorksharingLoopType::DistributeStaticLoop:
+    if (Bitwidth == 32)
+      return OMPBuilder->getOrCreateRuntimeFunction(
+          M, omp::RuntimeFunction::OMPRTL___kmpc_distribute_static_loop_4u);
+    if (Bitwidth == 64)
+      return OMPBuilder->getOrCreateRuntimeFunction(
+          M, omp::RuntimeFunction::OMPRTL___kmpc_distribute_static_loop_8u);
+    break;
+  case OpenMPIRBuilder::WorksharingLoopType::DistributeForStaticLoop:
+    if (Bitwidth == 32)
+      return OMPBuilder->getOrCreateRuntimeFunction(
+          M, omp::RuntimeFunction::OMPRTL___kmpc_distribute_for_static_loop_4u);
+    if (Bitwidth == 64)
+      return OMPBuilder->getOrCreateRuntimeFunction(
+          M, omp::RuntimeFunction::OMPRTL___kmpc_distribute_for_static_loop_8u);
+    break;
+  }
+  if (Bitwidth != 32 && Bitwidth != 64)
+    llvm_unreachable("unknown OpenMP loop iterator bitwidth");
+  return FunctionCallee();
+}
+
+// Inserts a call to the proper OpenMP Device RTL function which handles
+// loop worksharing.
+static void createTargetLoopWorkshareCall(
+    OpenMPIRBuilder *OMPBuilder, OpenMPIRBuilder::WorksharingLoopType LoopType,
+    BasicBlock *InsertBlock, Value *Ident, Value *LoopBodyArg,
+    Type *ParallelTaskPtr, Value *TripCount, Function &LoopBodyFn) {
+  Type *TripCountTy = TripCount->getType();
+  Module &M = OMPBuilder->M;
+  IRBuilder<> &Builder = OMPBuilder->Builder;
+  FunctionCallee RTLFn =
+      getKmpcForStaticLoopForType(TripCountTy, OMPBuilder, LoopType);
+  SmallVector<Value *, 8> RealArgs;
+  RealArgs.push_back(Ident);
+  /*loop body func*/
+  RealArgs.push_back(Builder.CreateBitCast(&LoopBodyFn, ParallelTaskPtr));
+  /*loop body args*/
+  RealArgs.push_back(LoopBodyArg);
+  /*num of iters*/
+  RealArgs.push_back(TripCount);
+  if (LoopType == OpenMPIRBuilder::WorksharingLoopType::DistributeStaticLoop) {
+    /*block chunk*/ RealArgs.push_back(TripCountTy->getIntegerBitWidth() == 32
+                                           ? Builder.getInt32(0)
+                                           : Builder.getInt64(0));
+    Builder.CreateCall(RTLFn, RealArgs);
+    return;
+  }
+  FunctionCallee RTLNumThreads = OMPBuilder->getOrCreateRuntimeFunction(
+      M, omp::RuntimeFunction::OMPRTL_omp_get_num_threads);
+  Builder.restoreIP({InsertBlock, std::prev(InsertBlock->end())});
+  Value *NumThreads = Builder.CreateCall(RTLNumThreads, {});
+
+  /*num of threads*/ RealArgs.push_back(
+      Builder.CreateZExtOrTrunc(NumThreads, TripCountTy, "num.threads.cast"));
+  if (LoopType ==
+      OpenMPIRBuilder::WorksharingLoopType::DistributeForStaticLoop) {
+    /*block chunk*/ RealArgs.push_back(TripCountTy->getIntegerBitWidth() == 32
+                                           ? Builder.getInt32(0)
+                                           : Builder.getInt64(0));
+  }
+  /*thread chunk*/ RealArgs.push_back(TripCountTy->getIntegerBitWidth() == 32
+                                          ? Builder.getInt32(1)
+                                          : Builder.getInt64(1));
+
+  Builder.CreateCall(RTLFn, RealArgs);
+}
+
+static void
+workshareLoopTargetCallback(OpenMPIRBuilder *OMPIRBuilder,
+                            CanonicalLoopInfo *CLI, Value *Ident,
+                            Function &OutlinedFn, Type *ParallelTaskPtr,
+                            const SmallVector<Instruction *, 4> &ToBeDeleted,
+                            OpenMPIRBuilder::WorksharingLoopType LoopType) {
+  IRBuilder<> &Builder = OMPIRBuilder->Builder;
+  BasicBlock *Preheader = CLI->getPreheader();
+  Value *TripCount = CLI->getTripCount();
+
+  // After loop body outlining, the loop body contains only the setup of the
+  // loop body argument structure and the call to the outlined loop body
+  // function. First, we need to move the setup of the loop body args into the
+  // loop preheader.
+  Preheader->splice(std::prev(Preheader->end()), CLI->getBody(),
+                    CLI->getBody()->begin(), std::prev(CLI->getBody()->end()));
+
+  // The next step is to remove the whole loop. We do not need it anymore.
+  // That's why we make an unconditional branch from the loop preheader to the
+  // loop exit block.
+  Builder.restoreIP({Preheader, Preheader->end()});
+  Preheader->getTerminator()->eraseFromParent();
+  Builder.CreateBr(CLI->getExit());
+
+  // Delete dead loop blocks.
+  OpenMPIRBuilder::OutlineInfo CleanUpInfo;
+  SmallPtrSet<BasicBlock *, 32> RegionBlockSet;
+  SmallVector<BasicBlock *, 32> BlocksToBeRemoved;
+  CleanUpInfo.EntryBB = CLI->getHeader();
+  CleanUpInfo.ExitBB = CLI->getExit();
+  CleanUpInfo.collectBlocks(RegionBlockSet, BlocksToBeRemoved);
+  DeleteDeadBlocks(BlocksToBeRemoved);
+
+  // Find the instruction which corresponds to the loop body argument structure
+  // and remove the call to the loop body function.
+  Value *LoopBodyArg;
+  for (auto instIt = Preheader->begin(); instIt != Preheader->end();
+       ++instIt) {
+    if (CallInst *CallInstruction = dyn_cast<CallInst>(instIt)) {
+      if (CallInstruction->getCalledFunction() == &OutlinedFn) {
+        // Check in case no argument structure has been passed.
+        if (CallInstruction->arg_size() > 1)
+          LoopBodyArg = CallInstruction->getArgOperand(1);
+        else
+          LoopBodyArg = Constant::getNullValue(Builder.getPtrTy());
+        CallInstruction->eraseFromParent();
+        break;
+      }
+    }
+  }
+
+  createTargetLoopWorkshareCall(OMPIRBuilder, LoopType, Preheader, Ident,
+                                LoopBodyArg, ParallelTaskPtr, TripCount,
+                                OutlinedFn);
+
+  for (auto &ToBeDeletedItem : ToBeDeleted)
+    ToBeDeletedItem->eraseFromParent();
+  CLI->invalidate();
+}
+
+OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::applyWorkshareLoopTarget(
+    DebugLoc DL, CanonicalLoopInfo *CLI, InsertPointTy AllocaIP,
+    OpenMPIRBuilder::WorksharingLoopType LoopType) {
+  uint32_t SrcLocStrSize;
+  Constant *SrcLocStr = getOrCreateSrcLocStr(DL, SrcLocStrSize);
+  Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
+
+  OutlineInfo OI;
+  OI.OuterAllocaBB = CLI->getPreheader();
+  Function *OuterFn = CLI->getPreheader()->getParent();
+
+  // Instructions which need to be deleted at the end of code generation.
+  SmallVector<Instruction *, 4> ToBeDeleted;
+
+  OI.OuterAllocaBB = AllocaIP.getBlock();
+
+  // Mark the loop body as the region which needs to be extracted.
+  OI.EntryBB = CLI->getBody();
+  OI.ExitBB = CLI->getLatch()->splitBasicBlock(CLI->getLatch()->begin(),
+                                               "omp.prelatch", true);
+
+  // Prepare the loop body for extraction.
+  Builder.restoreIP({CLI->getPreheader(), CLI->getPreheader()->begin()});
+
+  // Insert a new loop counter variable which will be used only in the loop
+  // body.
+  AllocaInst *newLoopCnt = Builder.CreateAlloca(CLI->getIndVarType(), 0, "");
+  Instruction *newLoopCntLoad =
+      Builder.CreateLoad(CLI->getIndVarType(), newLoopCnt);
+  // The new loop counter instructions are redundant in the loop preheader once
+  // code generation for the workshare loop is finished. That's why we mark
+  // them as ready for deletion.
+  ToBeDeleted.push_back(newLoopCntLoad);
+  ToBeDeleted.push_back(newLoopCnt);
+
+  // Analyse the loop body region. Find all input variables which are used
+  // inside the loop body region.
+  SmallPtrSet<BasicBlock *, 32> ParallelRegionBlockSet;
+  SmallVector<BasicBlock *, 32> Blocks;
+  OI.collectBlocks(ParallelRegionBlockSet, Blocks);
+  SmallVector<BasicBlock *, 32> BlocksT(ParallelRegionBlockSet.begin(),
+                                        ParallelRegionBlockSet.end());
+
+  CodeExtractorAnalysisCache CEAC(*OuterFn);
+  CodeExtractor Extractor(Blocks,
+                          /* DominatorTree */ nullptr,
+                          /* AggregateArgs */ true,
+                          /* BlockFrequencyInfo */ nullptr,
+                          /* BranchProbabilityInfo */ nullptr,
+                          /* AssumptionCache */ nullptr,
+                          /* AllowVarArgs */ true,
+                          /* AllowAlloca */ true,
+                          /* AllocationBlock */ CLI->getPreheader(),
+                          /* Suffix */ ".omp_wsloop",
+                          /* AggrArgsIn0AddrSpace */ true);
+
+  BasicBlock *CommonExit = nullptr;
+  SetVector<Value *> Inputs, Outputs, SinkingCands, HoistingCands;
+
+  // Find allocas outside the loop body region which are used inside the loop
+  // body.
+  Extractor.findAllocas(CEAC, SinkingCands, HoistingCands, CommonExit);
+
+  // We need to model the loop body region as the function f(cnt, loop_arg).
+  // That's why we replace the loop induction variable with the new counter,
+  // which will be one of the loop body function arguments.
+  std::vector<User *> Users(CLI->getIndVar()->user_begin(),
+                            CLI->getIndVar()->user_end());
+  for (User *use : Users) {
+    if (Instruction *inst = dyn_cast<Instruction>(use)) {
+      if (ParallelRegionBlockSet.count(inst->getParent())) {
+        inst->replaceUsesOfWith(CLI->getIndVar(), newLoopCntLoad);
+      }
+    }
+  }
+  Extractor.findInputsOutputs(Inputs, Outputs, SinkingCands);
+  for (Value *Input : Inputs) {
+    // Make sure that the loop counter variable is not merged into the loop
+    // body function argument structure and that it is passed as a separate
+    // variable.
+    if (Input == newLoopCntLoad)
+      OI.ExcludeArgsFromAggregate.push_back(Input);
+  }
+
+  // The PostOutline callback is invoked when the loop body function has been
+  // outlined and the loop body has been replaced by a call to the outlined
+  // function. We need to add a call to the OpenMP device RTL inside the loop
+  // preheader. The OpenMP device RTL function will handle the loop control
+  // logic.
+  OI.PostOutlineCB = [=, ToBeDeletedVec =
+                             std::move(ToBeDeleted)](Function &OutlinedFn) {
+    workshareLoopTargetCallback(this, CLI, Ident, OutlinedFn, ParallelTaskPtr,
+                                ToBeDeletedVec, LoopType);
+  };
+  addOutlineInfo(std::move(OI));
+  return CLI->getAfterIP();
+}
+
 OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::applyWorkshareLoop(
     DebugLoc DL, CanonicalLoopInfo *CLI, InsertPointTy AllocaIP,
     bool NeedsBarrier, llvm::omp::ScheduleKind SchedKind,
     llvm::Value *ChunkSize, bool HasSimdModifier, bool HasMonotonicModifier,
-    bool HasNonmonotonicModifier, bool HasOrderedClause) {
+    bool HasNonmonotonicModifier, bool HasOrderedClause,
+    OpenMPIRBuilder::WorksharingLoopType LoopType) {
+  if (Config.isTargetDevice())
+    return applyWorkshareLoopTarget(DL, CLI, AllocaIP, LoopType);
   OMPScheduleType EffectiveScheduleType = computeOpenMPScheduleType(
       SchedKind, ChunkSize, HasSimdModifier, HasMonotonicModifier,
       HasNonmonotonicModifier, HasOrderedClause);
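
For reviewers who want to exercise the new entry point, here is a minimal, hypothetical caller sketch (not part of this patch). It assumes the frontend already has an `OpenMPIRBuilder` configured for a target device, a `CanonicalLoopInfo` built via `createCanonicalLoop`, an alloca insertion point, a `DebugLoc`, and a schedule kind; only the trailing `WorksharingLoopType` argument is new.

```cpp
#include "llvm/Frontend/OpenMP/OMPIRBuilder.h"

// Hypothetical helper -- not part of this change. All parameters stand for
// values the frontend already has; only the trailing LoopType argument to
// applyWorkshareLoop is new in this patch.
static void
emitDeviceWorkshareLoop(llvm::OpenMPIRBuilder &OMPBuilder, llvm::DebugLoc DL,
                        llvm::CanonicalLoopInfo *CLI,
                        llvm::OpenMPIRBuilder::InsertPointTy AllocaIP,
                        llvm::omp::ScheduleKind SchedKind) {
  using WSLoopTy = llvm::OpenMPIRBuilder::WorksharingLoopType;

  // When Config.isTargetDevice() is true, this call takes the new
  // applyWorkshareLoopTarget path and emits a __kmpc_*_loop_{4u,8u} call in
  // the loop preheader instead of the host static init/fini sequence.
  llvm::OpenMPIRBuilder::InsertPointTy AfterIP = OMPBuilder.applyWorkshareLoop(
      DL, CLI, AllocaIP,
      /*NeedsBarrier=*/true, SchedKind,
      /*ChunkSize=*/nullptr,
      /*HasSimdModifier=*/false, /*HasMonotonicModifier=*/false,
      /*HasNonmonotonicModifier=*/false, /*HasOrderedClause=*/false,
      WSLoopTy::ForStaticLoop);

  // Continue emitting code after the lowered loop.
  OMPBuilder.Builder.restoreIP(AfterIP);
}
```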
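For reference, a sketch of the device runtime entry points this lowers to. The shapes below are inferred from the argument order pushed in `createTargetLoopWorkshareCall` above (i32 variants shown; the `_8u` forms take i64), so treat them as an illustrative assumption rather than the authoritative prototypes, which live in `OMPKinds.def` and the DeviceRTL.

```cpp
// Assumed shapes of the DeviceRTL worksharing entry points, reconstructed from
// how RealArgs is populated above; "ident" stands for the usual OpenMP
// source-location struct. Illustrative only -- consult OMPKinds.def for the
// authoritative signatures.
extern "C" {
// ForStaticLoop: ident, outlined body, body arg, trip count, num threads,
// thread chunk.
void __kmpc_for_static_loop_4u(void *ident, void (*fn)(unsigned, void *),
                               void *arg, unsigned num_iters,
                               unsigned num_threads, unsigned thread_chunk);

// DistributeStaticLoop: ident, outlined body, body arg, trip count,
// block chunk.
void __kmpc_distribute_static_loop_4u(void *ident,
                                      void (*fn)(unsigned, void *), void *arg,
                                      unsigned num_iters, unsigned block_chunk);

// DistributeForStaticLoop: ident, outlined body, body arg, trip count,
// num threads, block chunk, thread chunk.
void __kmpc_distribute_for_static_loop_4u(
    void *ident, void (*fn)(unsigned, void *), void *arg, unsigned num_iters,
    unsigned num_threads, unsigned block_chunk, unsigned thread_chunk);
}
```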