@@ -748,9 +748,39 @@ void SetArgBasedOnType(
748
748
const std::function<void *(Requirement *Req)> &getMemAllocationFunc,
749
749
const sycl::context &Context, detail::ArgDesc &Arg, size_t NextTrueIndex);
750
750
751
- void applyFuncOnFilteredArgs (
752
- const KernelArgMask *EliminatedArgMask, std::vector<ArgDesc> &Args,
753
- std::function<void (detail::ArgDesc &Arg, int NextTrueIndex)> Func);
751
+ template <typename FuncT>
752
+ void applyFuncOnFilteredArgs (const KernelArgMask *EliminatedArgMask,
753
+ std::vector<ArgDesc> &Args, FuncT Func) {
754
+ if (!EliminatedArgMask || EliminatedArgMask->size () == 0 ) {
755
+ for (ArgDesc &Arg : Args) {
756
+ Func (Arg, Arg.MIndex );
757
+ }
758
+ } else {
759
+ // TODO this is not necessary as long as we can guarantee that the
760
+ // arguments are already sorted (e. g. handle the sorting in handler
761
+ // if necessary due to set_arg(...) usage).
762
+ std::sort (Args.begin (), Args.end (), [](const ArgDesc &A, const ArgDesc &B) {
763
+ return A.MIndex < B.MIndex ;
764
+ });
765
+ int LastIndex = -1 ;
766
+ size_t NextTrueIndex = 0 ;
767
+
768
+ for (ArgDesc &Arg : Args) {
769
+ // Handle potential gaps in set arguments (e. g. if some of them are
770
+ // set on the user side).
771
+ for (int Idx = LastIndex + 1 ; Idx < Arg.MIndex ; ++Idx)
772
+ if (!(*EliminatedArgMask)[Idx])
773
+ ++NextTrueIndex;
774
+ LastIndex = Arg.MIndex ;
775
+
776
+ if ((*EliminatedArgMask)[Arg.MIndex ])
777
+ continue ;
778
+
779
+ Func (Arg, NextTrueIndex);
780
+ ++NextTrueIndex;
781
+ }
782
+ }
783
+ }
754
784
755
785
void ReverseRangeDimensionsForKernel (NDRDescT &NDR);
756
786
0 commit comments