@@ -2674,11 +2674,242 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::applyStaticChunkedWorkshareLoop(
2674
2674
return {DispatchAfter, DispatchAfter->getFirstInsertionPt ()};
2675
2675
}
2676
2676
2677
+ // Returns an LLVM function to call for executing an OpenMP static worksharing
2678
+ // for loop depending on `type`. Only i32 and i64 are supported by the runtime.
2679
+ // Always interpret integers as unsigned similarly to CanonicalLoopInfo.
2680
+ static FunctionCallee
2681
+ getKmpcForStaticLoopForType (Type *Ty, OpenMPIRBuilder *OMPBuilder,
2682
+ WorksharingLoopType LoopType) {
2683
+ unsigned Bitwidth = Ty->getIntegerBitWidth ();
2684
+ Module &M = OMPBuilder->M ;
2685
+ switch (LoopType) {
2686
+ case WorksharingLoopType::ForStaticLoop:
2687
+ if (Bitwidth == 32 )
2688
+ return OMPBuilder->getOrCreateRuntimeFunction (
2689
+ M, omp::RuntimeFunction::OMPRTL___kmpc_for_static_loop_4u);
2690
+ if (Bitwidth == 64 )
2691
+ return OMPBuilder->getOrCreateRuntimeFunction (
2692
+ M, omp::RuntimeFunction::OMPRTL___kmpc_for_static_loop_8u);
2693
+ break ;
2694
+ case WorksharingLoopType::DistributeStaticLoop:
2695
+ if (Bitwidth == 32 )
2696
+ return OMPBuilder->getOrCreateRuntimeFunction (
2697
+ M, omp::RuntimeFunction::OMPRTL___kmpc_distribute_static_loop_4u);
2698
+ if (Bitwidth == 64 )
2699
+ return OMPBuilder->getOrCreateRuntimeFunction (
2700
+ M, omp::RuntimeFunction::OMPRTL___kmpc_distribute_static_loop_8u);
2701
+ break ;
2702
+ case WorksharingLoopType::DistributeForStaticLoop:
2703
+ if (Bitwidth == 32 )
2704
+ return OMPBuilder->getOrCreateRuntimeFunction (
2705
+ M, omp::RuntimeFunction::OMPRTL___kmpc_distribute_for_static_loop_4u);
2706
+ if (Bitwidth == 64 )
2707
+ return OMPBuilder->getOrCreateRuntimeFunction (
2708
+ M, omp::RuntimeFunction::OMPRTL___kmpc_distribute_for_static_loop_8u);
2709
+ break ;
2710
+ }
2711
+ if (Bitwidth != 32 && Bitwidth != 64 ) {
2712
+ llvm_unreachable (" Unknown OpenMP loop iterator bitwidth" );
2713
+ }
2714
+ llvm_unreachable (" Unknown type of OpenMP worksharing loop" );
2715
+ }
2716
+
2717
// Inserts a call to the proper OpenMP Device RTL function which handles
// loop worksharing (see getKmpcForStaticLoopForType for the mapping from
// LoopType/bitwidth to the concrete runtime entry point).
//
// The runtime call receives: the source-location ident, the outlined loop
// body function (cast to the generic parallel-task pointer type), the loop
// body argument structure, and the trip count. For the non-distribute
// variants the current number of threads is appended as well.
// NOTE(review): the trailing zero constants appear to be the chunk-size /
// distribute arguments expected by the DeviceRTL signatures — confirm
// against DeviceRTL's Workshare definitions.
static void createTargetLoopWorkshareCall(
    OpenMPIRBuilder *OMPBuilder, WorksharingLoopType LoopType,
    BasicBlock *InsertBlock, Value *Ident, Value *LoopBodyArg,
    Type *ParallelTaskPtr, Value *TripCount, Function &LoopBodyFn) {
  Type *TripCountTy = TripCount->getType();
  Module &M = OMPBuilder->M;
  IRBuilder<> &Builder = OMPBuilder->Builder;
  FunctionCallee RTLFn =
      getKmpcForStaticLoopForType(TripCountTy, OMPBuilder, LoopType);
  SmallVector<Value *, 8> RealArgs;
  RealArgs.push_back(Ident);
  // The runtime takes the loop body as an opaque task pointer.
  RealArgs.push_back(Builder.CreateBitCast(&LoopBodyFn, ParallelTaskPtr));
  RealArgs.push_back(LoopBodyArg);
  RealArgs.push_back(TripCount);
  if (LoopType == WorksharingLoopType::DistributeStaticLoop) {
    // Distribute-only loops do not take a thread count; emit the call at the
    // builder's current insertion point and return early.
    RealArgs.push_back(ConstantInt::get(TripCountTy, 0));
    Builder.CreateCall(RTLFn, RealArgs);
    return;
  }
  FunctionCallee RTLNumThreads = OMPBuilder->getOrCreateRuntimeFunction(
      M, omp::RuntimeFunction::OMPRTL_omp_get_num_threads);
  // Query the thread count just before the terminator of the insert block so
  // the runtime call and its operands stay together.
  Builder.restoreIP({InsertBlock, std::prev(InsertBlock->end())});
  Value *NumThreads = Builder.CreateCall(RTLNumThreads, {});

  // omp_get_num_threads returns i32; widen/narrow to the trip-count type.
  RealArgs.push_back(
      Builder.CreateZExtOrTrunc(NumThreads, TripCountTy, "num.threads.cast"));
  RealArgs.push_back(ConstantInt::get(TripCountTy, 0));
  if (LoopType == WorksharingLoopType::DistributeForStaticLoop) {
    // The combined distribute-for variant takes one extra argument.
    RealArgs.push_back(ConstantInt::get(TripCountTy, 0));
  }

  Builder.CreateCall(RTLFn, RealArgs);
}
2752
+
2753
// Post-outlining callback for target-device workshare loops: replaces the
// (now redundant) canonical loop skeleton with a single call into the OpenMP
// Device RTL that performs the loop control itself, passing the outlined
// body function and its argument structure.
static void
workshareLoopTargetCallback(OpenMPIRBuilder *OMPIRBuilder,
                            CanonicalLoopInfo *CLI, Value *Ident,
                            Function &OutlinedFn, Type *ParallelTaskPtr,
                            const SmallVector<Instruction *, 4> &ToBeDeleted,
                            WorksharingLoopType LoopType) {
  IRBuilder<> &Builder = OMPIRBuilder->Builder;
  BasicBlock *Preheader = CLI->getPreheader();
  Value *TripCount = CLI->getTripCount();

  // After loop body outlining, the loop body contains only the setup of the
  // loop body argument structure and the call to the outlined loop body
  // function. First, move the setup of the loop body args into the loop
  // preheader (everything but the body block's terminator).
  Preheader->splice(std::prev(Preheader->end()), CLI->getBody(),
                    CLI->getBody()->begin(), std::prev(CLI->getBody()->end()));

  // The next step is to remove the whole loop; it is not needed anymore.
  // Replace the preheader's terminator with an unconditional branch straight
  // to the loop exit block.
  Builder.restoreIP({Preheader, Preheader->end()});
  Preheader->getTerminator()->eraseFromParent();
  Builder.CreateBr(CLI->getExit());

  // Delete the dead loop blocks (header..latch), collected via the region
  // between the loop header and exit.
  OpenMPIRBuilder::OutlineInfo CleanUpInfo;
  SmallPtrSet<BasicBlock *, 32> RegionBlockSet;
  SmallVector<BasicBlock *, 32> BlocksToBeRemoved;
  CleanUpInfo.EntryBB = CLI->getHeader();
  CleanUpInfo.ExitBB = CLI->getExit();
  CleanUpInfo.collectBlocks(RegionBlockSet, BlocksToBeRemoved);
  DeleteDeadBlocks(BlocksToBeRemoved);

  // Find the value which corresponds to the loop body argument structure and
  // remove the call to the outlined loop body function; the runtime will
  // invoke the body itself.
  Value *LoopBodyArg;
  User *OutlinedFnUser = OutlinedFn.getUniqueUndroppableUser();
  assert(OutlinedFnUser &&
         "Expected unique undroppable user of outlined function");
  CallInst *OutlinedFnCallInstruction = dyn_cast<CallInst>(OutlinedFnUser);
  assert(OutlinedFnCallInstruction && "Expected outlined function call");
  assert((OutlinedFnCallInstruction->getParent() == Preheader) &&
         "Expected outlined function call to be located in loop preheader");
  // Check in case no argument structure has been passed (arg 0 is the loop
  // counter excluded from the aggregate; arg 1, if present, is the struct).
  if (OutlinedFnCallInstruction->arg_size() > 1)
    LoopBodyArg = OutlinedFnCallInstruction->getArgOperand(1);
  else
    LoopBodyArg = Constant::getNullValue(Builder.getPtrTy());
  OutlinedFnCallInstruction->eraseFromParent();

  // Emit the Device RTL call that performs the actual worksharing.
  createTargetLoopWorkshareCall(OMPIRBuilder, LoopType, Preheader, Ident,
                                LoopBodyArg, ParallelTaskPtr, TripCount,
                                OutlinedFn);

  // Drop the temporary counter instructions that were only needed during
  // outlining, then invalidate the (now dismantled) canonical loop info.
  for (auto &ToBeDeletedItem : ToBeDeleted)
    ToBeDeletedItem->eraseFromParent();
  CLI->invalidate();
}
2811
+
2812
+ OpenMPIRBuilder::InsertPointTy
2813
+ OpenMPIRBuilder::applyWorkshareLoopTarget (DebugLoc DL, CanonicalLoopInfo *CLI,
2814
+ InsertPointTy AllocaIP,
2815
+ WorksharingLoopType LoopType) {
2816
+ uint32_t SrcLocStrSize;
2817
+ Constant *SrcLocStr = getOrCreateSrcLocStr (DL, SrcLocStrSize);
2818
+ Value *Ident = getOrCreateIdent (SrcLocStr, SrcLocStrSize);
2819
+
2820
+ OutlineInfo OI;
2821
+ OI.OuterAllocaBB = CLI->getPreheader ();
2822
+ Function *OuterFn = CLI->getPreheader ()->getParent ();
2823
+
2824
+ // Instructions which need to be deleted at the end of code generation
2825
+ SmallVector<Instruction *, 4 > ToBeDeleted;
2826
+
2827
+ OI.OuterAllocaBB = AllocaIP.getBlock ();
2828
+
2829
+ // Mark the body loop as region which needs to be extracted
2830
+ OI.EntryBB = CLI->getBody ();
2831
+ OI.ExitBB = CLI->getLatch ()->splitBasicBlock (CLI->getLatch ()->begin (),
2832
+ " omp.prelatch" , true );
2833
+
2834
+ // Prepare loop body for extraction
2835
+ Builder.restoreIP ({CLI->getPreheader (), CLI->getPreheader ()->begin ()});
2836
+
2837
+ // Insert new loop counter variable which will be used only in loop
2838
+ // body.
2839
+ AllocaInst *NewLoopCnt = Builder.CreateAlloca (CLI->getIndVarType (), 0 , " " );
2840
+ Instruction *NewLoopCntLoad =
2841
+ Builder.CreateLoad (CLI->getIndVarType (), NewLoopCnt);
2842
+ // New loop counter instructions are redundant in the loop preheader when
2843
+ // code generation for workshare loop is finshed. That's why mark them as
2844
+ // ready for deletion.
2845
+ ToBeDeleted.push_back (NewLoopCntLoad);
2846
+ ToBeDeleted.push_back (NewLoopCnt);
2847
+
2848
+ // Analyse loop body region. Find all input variables which are used inside
2849
+ // loop body region.
2850
+ SmallPtrSet<BasicBlock *, 32 > ParallelRegionBlockSet;
2851
+ SmallVector<BasicBlock *, 32 > Blocks;
2852
+ OI.collectBlocks (ParallelRegionBlockSet, Blocks);
2853
+ SmallVector<BasicBlock *, 32 > BlocksT (ParallelRegionBlockSet.begin (),
2854
+ ParallelRegionBlockSet.end ());
2855
+
2856
+ CodeExtractorAnalysisCache CEAC (*OuterFn);
2857
+ CodeExtractor Extractor (Blocks,
2858
+ /* DominatorTree */ nullptr ,
2859
+ /* AggregateArgs */ true ,
2860
+ /* BlockFrequencyInfo */ nullptr ,
2861
+ /* BranchProbabilityInfo */ nullptr ,
2862
+ /* AssumptionCache */ nullptr ,
2863
+ /* AllowVarArgs */ true ,
2864
+ /* AllowAlloca */ true ,
2865
+ /* AllocationBlock */ CLI->getPreheader (),
2866
+ /* Suffix */ " .omp_wsloop" ,
2867
+ /* AggrArgsIn0AddrSpace */ true );
2868
+
2869
+ BasicBlock *CommonExit = nullptr ;
2870
+ SetVector<Value *> Inputs, Outputs, SinkingCands, HoistingCands;
2871
+
2872
+ // Find allocas outside the loop body region which are used inside loop
2873
+ // body
2874
+ Extractor.findAllocas (CEAC, SinkingCands, HoistingCands, CommonExit);
2875
+
2876
+ // We need to model loop body region as the function f(cnt, loop_arg).
2877
+ // That's why we replace loop induction variable by the new counter
2878
+ // which will be one of loop body function argument
2879
+ for (auto Use = CLI->getIndVar ()->user_begin ();
2880
+ Use != CLI->getIndVar ()->user_end (); ++Use) {
2881
+ if (Instruction *Inst = dyn_cast<Instruction>(*Use)) {
2882
+ if (ParallelRegionBlockSet.count (Inst->getParent ())) {
2883
+ Inst->replaceUsesOfWith (CLI->getIndVar (), NewLoopCntLoad);
2884
+ }
2885
+ }
2886
+ }
2887
+ // Make sure that loop counter variable is not merged into loop body
2888
+ // function argument structure and it is passed as separate variable
2889
+ OI.ExcludeArgsFromAggregate .push_back (NewLoopCntLoad);
2890
+
2891
+ // PostOutline CB is invoked when loop body function is outlined and
2892
+ // loop body is replaced by call to outlined function. We need to add
2893
+ // call to OpenMP device rtl inside loop preheader. OpenMP device rtl
2894
+ // function will handle loop control logic.
2895
+ //
2896
+ OI.PostOutlineCB = [=, ToBeDeletedVec =
2897
+ std::move (ToBeDeleted)](Function &OutlinedFn) {
2898
+ workshareLoopTargetCallback (this , CLI, Ident, OutlinedFn, ParallelTaskPtr,
2899
+ ToBeDeletedVec, LoopType);
2900
+ };
2901
+ addOutlineInfo (std::move (OI));
2902
+ return CLI->getAfterIP ();
2903
+ }
2904
+
2677
2905
OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::applyWorkshareLoop (
2678
2906
DebugLoc DL, CanonicalLoopInfo *CLI, InsertPointTy AllocaIP,
2679
- bool NeedsBarrier, llvm::omp::ScheduleKind SchedKind,
2680
- llvm::Value *ChunkSize, bool HasSimdModifier, bool HasMonotonicModifier,
2681
- bool HasNonmonotonicModifier, bool HasOrderedClause) {
2907
+ bool NeedsBarrier, omp::ScheduleKind SchedKind, Value *ChunkSize,
2908
+ bool HasSimdModifier, bool HasMonotonicModifier,
2909
+ bool HasNonmonotonicModifier, bool HasOrderedClause,
2910
+ WorksharingLoopType LoopType) {
2911
+ if (Config.isTargetDevice ())
2912
+ return applyWorkshareLoopTarget (DL, CLI, AllocaIP, LoopType);
2682
2913
OMPScheduleType EffectiveScheduleType = computeOpenMPScheduleType (
2683
2914
SchedKind, ChunkSize, HasSimdModifier, HasMonotonicModifier,
2684
2915
HasNonmonotonicModifier, HasOrderedClause);
0 commit comments