@@ -45,6 +45,9 @@ struct DynamicScheduleTracker {
 
 #pragma omp begin declare target device_type(nohost)
 
+extern int32_t __omp_rtl_assume_teams_oversubscription;
+extern int32_t __omp_rtl_assume_threads_oversubscription;
+
 // TODO: This variable is a hack inherited from the old runtime.
 static uint64_t SHARED(Cnt);
 
@@ -636,4 +639,255 @@ void __kmpc_for_static_fini(IdentTy *loc, int32_t global_tid) {}
 void __kmpc_distribute_static_fini(IdentTy *loc, int32_t global_tid) {}
 }
 
+namespace ompx {
+
+/// Helper class to hide the generic loop nest and provide the template argument
+/// throughout.
+template <typename Ty> class StaticLoopChunker {
+
+  /// Generic loop nest that handles block and/or thread distribution in the
+  /// absence of user specified chunk sizes. This implicitly picks a block chunk
+  /// size equal to the number of threads in the block and a thread chunk size
+  /// equal to one. In contrast to the chunked version we can get away with a
+  /// single loop in this case.
+  static void NormalizedLoopNestNoChunk(void (*LoopBody)(Ty, void *), void *Arg,
+                                        Ty NumBlocks, Ty BId, Ty NumThreads,
+                                        Ty TId, Ty NumIters,
+                                        bool OneIterationPerThread) {
+    Ty KernelIteration = NumBlocks * NumThreads;
+
+    // Start index in the normalized space.
+    Ty IV = BId * NumThreads + TId;
+    ASSERT(IV >= 0, "Bad index");
+
+    // Cover the entire iteration space, assumptions in the caller might allow
+    // to simplify this loop to a conditional.
+    if (IV < NumIters) {
+      do {
+
+        // Execute the loop body.
+        LoopBody(IV, Arg);
+
+        // Every thread executed one block and thread chunk now.
+        IV += KernelIteration;
+
+        if (OneIterationPerThread)
+          return;
+
+      } while (IV < NumIters);
+    }
+  }
+
+  /// Generic loop nest that handles block and/or thread distribution in the
+  /// presence of user specified chunk sizes (for at least one of them).
+  static void NormalizedLoopNestChunked(void (*LoopBody)(Ty, void *), void *Arg,
+                                        Ty BlockChunk, Ty NumBlocks, Ty BId,
+                                        Ty ThreadChunk, Ty NumThreads, Ty TId,
+                                        Ty NumIters,
+                                        bool OneIterationPerThread) {
+    Ty KernelIteration = NumBlocks * BlockChunk;
+
+    // Start index in the chunked space.
+    Ty IV = BId * BlockChunk + TId;
+    ASSERT(IV >= 0, "Bad index");
+
+    // Cover the entire iteration space, assumptions in the caller might allow
+    // to simplify this loop to a conditional.
+    do {
+
+      Ty BlockChunkLeft =
+          BlockChunk >= TId * ThreadChunk ? BlockChunk - TId * ThreadChunk : 0;
+      Ty ThreadChunkLeft =
+          ThreadChunk <= BlockChunkLeft ? ThreadChunk : BlockChunkLeft;
+
+      while (ThreadChunkLeft--) {
+
+        // Given the blocking it's hard to keep track of what to execute.
+        if (IV >= NumIters)
+          return;
+
+        // Execute the loop body.
+        LoopBody(IV, Arg);
+
+        if (OneIterationPerThread)
+          return;
+
+        ++IV;
+      }
+
+      IV += KernelIteration;
+
+    } while (IV < NumIters);
+  }
+
+public:
+  /// Worksharing `for`-loop.
+  static void For(IdentTy *Loc, void (*LoopBody)(Ty, void *), void *Arg,
+                  Ty NumIters, Ty NumThreads, Ty ThreadChunk) {
+    ASSERT(NumIters >= 0, "Bad iteration count");
+    ASSERT(ThreadChunk >= 0, "Bad thread count");
+
+    // All threads need to participate but we don't know if we are in a
+    // parallel at all or if the user might have used a `num_threads` clause
+    // on the parallel and reduced the number compared to the block size.
+    // Since nested parallels are possible too we need to get the thread id
+    // from the `omp` getter and not the mapping directly.
+    Ty TId = omp_get_thread_num();
+
+    // There are no blocks involved here.
+    Ty BlockChunk = 0;
+    Ty NumBlocks = 1;
+    Ty BId = 0;
+
+    // If the thread chunk is not specified we pick a default now.
+    if (ThreadChunk == 0)
+      ThreadChunk = 1;
+
+    // If we know we have more threads than iterations we can indicate that to
+    // avoid an outer loop.
+    bool OneIterationPerThread = false;
+    if (__omp_rtl_assume_threads_oversubscription) {
+      ASSERT(NumThreads >= NumIters, "Broken assumption");
+      OneIterationPerThread = true;
+    }
+
+    if (ThreadChunk != 1)
+      NormalizedLoopNestChunked(LoopBody, Arg, BlockChunk, NumBlocks, BId,
+                                ThreadChunk, NumThreads, TId, NumIters,
+                                OneIterationPerThread);
+    else
+      NormalizedLoopNestNoChunk(LoopBody, Arg, NumBlocks, BId, NumThreads, TId,
+                                NumIters, OneIterationPerThread);
+  }
+
+  /// Worksharing `distribute`-loop.
+  static void Distribute(IdentTy *Loc, void (*LoopBody)(Ty, void *), void *Arg,
+                         Ty NumIters, Ty BlockChunk) {
+    ASSERT(icv::Level == 0, "Bad distribute");
+    ASSERT(icv::ActiveLevel == 0, "Bad distribute");
+    ASSERT(state::ParallelRegionFn == nullptr, "Bad distribute");
+    ASSERT(state::ParallelTeamSize == 1, "Bad distribute");
+
+    ASSERT(NumIters >= 0, "Bad iteration count");
+    ASSERT(BlockChunk >= 0, "Bad block count");
+
+    // There are no threads involved here.
+    Ty ThreadChunk = 0;
+    Ty NumThreads = 1;
+    Ty TId = 0;
+    ASSERT(TId == mapping::getThreadIdInBlock(), "Bad thread id");
+
+    // All teams need to participate.
+    Ty NumBlocks = mapping::getNumberOfBlocksInKernel();
+    Ty BId = mapping::getBlockIdInKernel();
+
+    // If the block chunk is not specified we pick a default now.
+    if (BlockChunk == 0)
+      BlockChunk = NumThreads;
+
+    // If we know we have more blocks than iterations we can indicate that to
+    // avoid an outer loop.
+    bool OneIterationPerThread = false;
+    if (__omp_rtl_assume_teams_oversubscription) {
+      ASSERT(NumBlocks >= NumIters, "Broken assumption");
+      OneIterationPerThread = true;
+    }
+
+    if (BlockChunk != NumThreads)
+      NormalizedLoopNestChunked(LoopBody, Arg, BlockChunk, NumBlocks, BId,
+                                ThreadChunk, NumThreads, TId, NumIters,
+                                OneIterationPerThread);
+    else
+      NormalizedLoopNestNoChunk(LoopBody, Arg, NumBlocks, BId, NumThreads, TId,
+                                NumIters, OneIterationPerThread);
+
+    ASSERT(icv::Level == 0, "Bad distribute");
+    ASSERT(icv::ActiveLevel == 0, "Bad distribute");
+    ASSERT(state::ParallelRegionFn == nullptr, "Bad distribute");
+    ASSERT(state::ParallelTeamSize == 1, "Bad distribute");
+  }
+
+  /// Worksharing `distribute parallel for`-loop.
+  static void DistributeFor(IdentTy *Loc, void (*LoopBody)(Ty, void *),
+                            void *Arg, Ty NumIters, Ty NumThreads,
+                            Ty BlockChunk, Ty ThreadChunk) {
+    ASSERT(icv::Level == 1, "Bad distribute");
+    ASSERT(icv::ActiveLevel == 1, "Bad distribute");
+    ASSERT(state::ParallelRegionFn == nullptr, "Bad distribute");
+
+    ASSERT(NumIters >= 0, "Bad iteration count");
+    ASSERT(BlockChunk >= 0, "Bad block count");
+    ASSERT(ThreadChunk >= 0, "Bad thread count");
+
+    // All threads need to participate but the user might have used a
+    // `num_threads` clause on the parallel and reduced the number compared to
+    // the block size.
+    Ty TId = mapping::getThreadIdInBlock();
+
+    // All teams need to participate.
+    Ty NumBlocks = mapping::getNumberOfBlocksInKernel();
+    Ty BId = mapping::getBlockIdInKernel();
+
+    // If the block chunk is not specified we pick a default now.
+    if (BlockChunk == 0)
+      BlockChunk = NumThreads;
+
+    // If the thread chunk is not specified we pick a default now.
+    if (ThreadChunk == 0)
+      ThreadChunk = 1;
+
+    // If we know we have more threads (across all blocks) than iterations we
+    // can indicate that to avoid an outer loop.
+    bool OneIterationPerThread = false;
+    if (__omp_rtl_assume_teams_oversubscription &
+        __omp_rtl_assume_threads_oversubscription) {
+      OneIterationPerThread = true;
+      ASSERT(NumBlocks * NumThreads >= NumIters, "Broken assumption");
+    }
+
+    if (BlockChunk != NumThreads || ThreadChunk != 1)
+      NormalizedLoopNestChunked(LoopBody, Arg, BlockChunk, NumBlocks, BId,
+                                ThreadChunk, NumThreads, TId, NumIters,
+                                OneIterationPerThread);
+    else
+      NormalizedLoopNestNoChunk(LoopBody, Arg, NumBlocks, BId, NumThreads, TId,
+                                NumIters, OneIterationPerThread);
+
+    ASSERT(icv::Level == 1, "Bad distribute");
+    ASSERT(icv::ActiveLevel == 1, "Bad distribute");
+    ASSERT(state::ParallelRegionFn == nullptr, "Bad distribute");
+  }
+};
+
+} // namespace ompx
+
+#define OMP_LOOP_ENTRY(BW, TY)                                                \
+  [[gnu::flatten, clang::always_inline]] void                                 \
+      __kmpc_distribute_for_static_loop##BW(                                  \
+          IdentTy *loc, void (*fn)(TY, void *), void *arg, TY num_iters,      \
+          TY num_threads, TY block_chunk, TY thread_chunk) {                  \
+    ompx::StaticLoopChunker<TY>::DistributeFor(                               \
+        loc, fn, arg, num_iters + 1, num_threads, block_chunk, thread_chunk); \
+  }                                                                           \
+  [[gnu::flatten, clang::always_inline]] void                                 \
+      __kmpc_distribute_static_loop##BW(IdentTy *loc, void (*fn)(TY, void *), \
+                                        void *arg, TY num_iters,              \
+                                        TY block_chunk) {                     \
+    ompx::StaticLoopChunker<TY>::Distribute(loc, fn, arg, num_iters + 1,      \
+                                            block_chunk);                     \
+  }                                                                           \
+  [[gnu::flatten, clang::always_inline]] void __kmpc_for_static_loop##BW(     \
+      IdentTy *loc, void (*fn)(TY, void *), void *arg, TY num_iters,          \
+      TY num_threads, TY thread_chunk) {                                      \
+    ompx::StaticLoopChunker<TY>::For(loc, fn, arg, num_iters + 1,             \
+                                     num_threads, thread_chunk);              \
+  }
+
+extern "C" {
+OMP_LOOP_ENTRY(_4, int32_t)
+OMP_LOOP_ENTRY(_4u, uint32_t)
+OMP_LOOP_ENTRY(_8, int64_t)
+OMP_LOOP_ENTRY(_8u, uint64_t)
+}
+
 #pragma omp end declare target
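To illustrate how the no-chunk path covers the iteration space, here is a minimal host-side C++ sketch; it is not part of the patch, it merely replays the start index and stride used by NormalizedLoopNestNoChunk, and the block, thread, and iteration counts are made-up example values.

// Host-side sketch, not part of the patch: replays the start index and stride
// of NormalizedLoopNestNoChunk to show the block-cyclic iteration mapping.
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  // Example values only; the real runtime queries these from its mapping layer.
  const int64_t NumBlocks = 2, NumThreads = 4, NumIters = 11;
  const int64_t KernelIteration = NumBlocks * NumThreads; // stride per round

  std::vector<int64_t> Owner(NumIters, -1);
  for (int64_t BId = 0; BId < NumBlocks; ++BId)
    for (int64_t TId = 0; TId < NumThreads; ++TId)
      // Same start index and stride as the device code above.
      for (int64_t IV = BId * NumThreads + TId; IV < NumIters;
           IV += KernelIteration)
        Owner[IV] = BId * NumThreads + TId;

  // Every iteration is assigned exactly once, in a block-cyclic pattern.
  for (int64_t IV = 0; IV < NumIters; ++IV)
    std::printf("iteration %2lld -> global thread %lld\n", (long long)IV,
                (long long)Owner[IV]);
}

When oversubscription is assumed, OneIterationPerThread makes the device loop return after the first body execution, which is exactly the simplification the __omp_rtl_assume_*_oversubscription flags enable.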
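For reference, a hand expansion of one instantiation, OMP_LOOP_ENTRY(_4, int32_t), is sketched below for the plain worksharing-loop entry point only. It is illustrative rather than standalone: the exact preprocessor output and formatting may differ, and it relies on the runtime's IdentTy and the ompx::StaticLoopChunker class added above.

// Illustrative hand expansion of OMP_LOOP_ENTRY(_4, int32_t); not additional
// code in the patch, and it assumes the runtime's IdentTy definition.
[[gnu::flatten, clang::always_inline]] void __kmpc_for_static_loop_4(
    IdentTy *loc, void (*fn)(int32_t, void *), void *arg, int32_t num_iters,
    int32_t num_threads, int32_t thread_chunk) {
  // The entry point forwards `num_iters + 1` iterations to the chunker,
  // exactly as written in the macro above.
  ompx::StaticLoopChunker<int32_t>::For(loc, fn, arg, num_iters + 1,
                                        num_threads, thread_chunk);
}

The other two functions produced by the macro expand analogously for Distribute and DistributeFor, and the _4u, _8, and _8u instantiations differ only in the integer type.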