@@ -700,6 +700,139 @@ struct AMDGPUKernelTy : public GenericKernelTy {
   std::optional<utils::KernelMetaDataTy> KernelInfo;
   /// CodeGen generate WGSize
   uint16_t ConstWGSize;
+
+  /// Get the number of threads and blocks for the kernel based on the
+  /// user-defined threads and block clauses.
+  uint32_t getNumThreads(GenericDeviceTy &GenericDevice,
+                         uint32_t ThreadLimitClause[3]) const override {
+    if (isNoLoopMode() || isBigJumpLoopMode() || isXTeamReductionsMode())
+      return ConstWGSize;
+
+    assert(ThreadLimitClause[1] == 0 && ThreadLimitClause[2] == 0 &&
+           "Multi dimensional launch not supported yet.");
+
+    if (ThreadLimitClause[0] > 0 && isGenericMode()) {
+      if (ThreadLimitClause[0] == (uint32_t)-1)
+        ThreadLimitClause[0] = PreferredNumThreads;
+      else
+        ThreadLimitClause[0] += GenericDevice.getWarpSize();
+    }
+
+    return std::min(MaxNumThreads, (ThreadLimitClause[0] > 0)
+                                       ? ThreadLimitClause[0]
+                                       : PreferredNumThreads);
+  }
+  uint64_t getNumBlocks(GenericDeviceTy &GenericDevice,
+                        uint32_t NumTeamsClause[3], uint64_t LoopTripCount,
+                        uint32_t NumThreads) const override {
+    assert(NumTeamsClause[1] == 0 && NumTeamsClause[2] == 0 &&
+           "Multi dimensional launch not supported yet.");
+
+    const auto getNumGroupsFromThreadsAndTripCount =
+        [](const uint64_t TripCount, const uint32_t NumThreads) {
+          return ((TripCount - 1) / NumThreads) + 1;
+        };
+    uint64_t DeviceNumCUs = GenericDevice.getNumComputeUnits(); // FIXME
+
+    if (isNoLoopMode()) {
+      return LoopTripCount > 0 ? getNumGroupsFromThreadsAndTripCount(
+                                     LoopTripCount, NumThreads)
+                               : 1;
+    }
+
+    if (isBigJumpLoopMode()) {
+      uint64_t NumGroups = 1;
+      // Cannot assert a non-zero tripcount. Instead, launch with 1 team if the
+      // tripcount is indeed zero.
+      if (LoopTripCount > 0)
+        NumGroups =
+            getNumGroupsFromThreadsAndTripCount(LoopTripCount, NumThreads);
+
+      // Honor num_teams clause but lower it if tripcount dictates to
+      if (NumTeamsClause[0] > 0 &&
+          NumTeamsClause[0] <= GenericDevice.getBlockLimit()) {
+        NumGroups =
+            std::min(static_cast<uint64_t>(NumTeamsClause[0]), NumGroups);
+      } else {
+        // num_teams clause is not specified. Choose lower of tripcount-based
+        // num-groups and a value that maximizes occupancy. At this point, aim
+        // to have 16 wavefronts in a CU.
+        // TODO: This logic needs to be moved to the AMDGPU plugin.
+        uint64_t NumWavesInGroup = NumThreads / GenericDevice.getWarpSize();
+        uint64_t MaxOccupancyFactor =
+            NumWavesInGroup ? (16 / NumWavesInGroup) : 16;
+        NumGroups = std::min(NumGroups, MaxOccupancyFactor * DeviceNumCUs);
+      }
+      return NumGroups;
+    }
+
+    if (isXTeamReductionsMode()) {
+      uint64_t NumGroups = 0;
+      if (NumTeamsClause[0] > 0 &&
+          NumTeamsClause[0] <= GenericDevice.getBlockLimit()) {
+        NumGroups = NumTeamsClause[0];
+      } else {
+        // If num_teams clause is not specified, we allow a max of 2*CU teams
+        if (NumThreads > 0) {
+          const uint64_t UIntTwo = 2;
+          NumGroups =
+              DeviceNumCUs *
+              std::min(UIntTwo, static_cast<uint64_t>(1024 / NumThreads));
+        } else {
+          NumGroups = DeviceNumCUs;
+        }
+        // Ensure we don't have a large number of teams running if the
+        // tripcount is low
+        uint64_t NumGroupsFromTripCount = 1;
+        if (LoopTripCount > 0)
+          NumGroupsFromTripCount =
+              getNumGroupsFromThreadsAndTripCount(LoopTripCount, NumThreads);
+        NumGroups = std::min(NumGroups, NumGroupsFromTripCount);
+      }
+      // For now, we don't allow number of teams beyond 512.
+      uint64_t fiveTwelve = 512;
+      NumGroups = std::min(fiveTwelve, NumGroups);
+      return NumGroups;
+    }
+
+    if (NumTeamsClause[0] > 0) {
+      // TODO: We need to honor any value and consequently allow more than the
+      // block limit. For this we might need to start multiple kernels or let
+      // the blocks start again until the requested number has been started.
+      return std::min(NumTeamsClause[0], GenericDevice.getBlockLimit());
+    }
+
+    uint64_t TripCountNumBlocks = std::numeric_limits<uint64_t>::max();
+    if (LoopTripCount > 0) {
+      if (isSPMDMode()) {
+        // We have a combined construct, i.e. `target teams distribute
+        // parallel for [simd]`. We launch so many teams so that each thread
+        // will execute one iteration of the loop. Round up to the nearest
+        // integer.
+        TripCountNumBlocks = ((LoopTripCount - 1) / NumThreads) + 1;
+      } else {
+        assert((isGenericMode() || isGenericSPMDMode()) &&
+               "Unexpected execution mode!");
+        // If we reach this point, then we have a non-combined construct, i.e.
+        // `teams distribute` with a nested `parallel for` and each team is
+        // assigned one iteration of the `distribute` loop. E.g.:
+        //
+        // #pragma omp target teams distribute
+        // for(...loop_tripcount...) {
+        //   #pragma omp parallel for
+        //   for(...) {}
+        // }
+        //
+        // Threads within a team will execute the iterations of the `parallel`
+        // loop.
+        TripCountNumBlocks = LoopTripCount;
+      }
+    }
+    // If the loops are long running we rather reuse blocks than spawn too many.
+    uint32_t PreferredNumBlocks = std::min(uint32_t(TripCountNumBlocks),
+                                           getDefaultNumBlocks(GenericDevice));
+    return std::min(PreferredNumBlocks, GenericDevice.getBlockLimit());
+  }
 };
 
 /// Class representing an HSA signal. Signals are used to define dependencies
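
Note: the team-count heuristics in the hunk above reduce to a small amount of integer arithmetic: a ceiling division of the loop trip count by the workgroup size, an occupancy cap of roughly 16 wavefronts per compute unit on the big-jump-loop path, and a cap of min(2 * CUs, 512) teams on the cross-team-reduction path when no num_teams clause is given. The standalone sketch below only reproduces that arithmetic for illustration; the device values (104 CUs, 64-wide wavefronts, 256 threads per workgroup, 2^20 iterations) are hypothetical and not taken from the plugin.

#include <algorithm>
#include <cstdint>
#include <cstdio>

// Ceiling division used by the plugin's local lambda: one team per
// NumThreads loop iterations, rounded up.
static uint64_t groupsFromTripCount(uint64_t TripCount, uint32_t NumThreads) {
  return ((TripCount - 1) / NumThreads) + 1;
}

int main() {
  // Hypothetical device and kernel parameters, chosen only for illustration.
  const uint64_t NumCUs = 104;     // compute units
  const uint32_t WarpSize = 64;    // wavefront size
  const uint32_t NumThreads = 256; // workgroup size
  const uint64_t TripCount = 1 << 20;

  // Big-jump-loop path without a num_teams clause: trip-count-based team
  // count, capped so that each CU holds roughly 16 wavefronts.
  uint64_t NumGroups = groupsFromTripCount(TripCount, NumThreads);
  const uint64_t WavesInGroup = NumThreads / WarpSize;
  const uint64_t MaxOccupancyFactor = WavesInGroup ? (16 / WavesInGroup) : 16;
  NumGroups = std::min(NumGroups, MaxOccupancyFactor * NumCUs);
  std::printf("big-jump-loop teams:   %llu\n",
              (unsigned long long)NumGroups); // 416 = 4 * 104

  // Cross-team-reduction path without a num_teams clause: at most two teams
  // per CU (while two workgroups of NumThreads fit in 1024 lanes), clamped by
  // the trip count and by the hard limit of 512 teams.
  uint64_t XTeamGroups = NumCUs * std::min<uint64_t>(2, 1024 / NumThreads);
  XTeamGroups =
      std::min(XTeamGroups, groupsFromTripCount(TripCount, NumThreads));
  XTeamGroups = std::min<uint64_t>(XTeamGroups, 512);
  std::printf("xteam-reduction teams: %llu\n",
              (unsigned long long)XTeamGroups); // 208 = 2 * 104
  return 0;
}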