@@ -203,10 +203,17 @@ static Promotion getInternalizationInfo(Requirement *Req) {
203
203
return (AccPromotion != Promotion::None) ? AccPromotion : BuffPromotion;
204
204
}
205
205
206
- static std::optional<size_t > getLocalSize (NDRDescT NDRange, Requirement *Req,
207
- Promotion Target) {
206
+ static std::optional<size_t > getLocalSize (NDRDescT NDRange,
207
+ std::optional<size_t > UserGlobalSize,
208
+ Requirement *Req, Promotion Target) {
209
+ assert ((!UserGlobalSize.has_value () || Target != Promotion::Local) &&
210
+ " Unexpected range rounding" );
208
211
auto NumElementsMem = static_cast <SYCLMemObjT *>(Req->MSYCLMemObj )->size ();
209
212
if (Target == Promotion::Private) {
213
+ if (UserGlobalSize.has_value ()) {
214
+ // Only the first dimension is affected by range rounding.
215
+ NDRange.GlobalSize [0 ] = *UserGlobalSize;
216
+ }
210
217
auto NumWorkItems = NDRange.GlobalSize .size ();
211
218
// For private internalization, the local size is
212
219
// (Number of elements in buffer)/(number of work-items)
@@ -237,13 +244,15 @@ static bool accessorEquals(Requirement *Req, Requirement *Other) {
237
244
238
245
static void resolveInternalization (ArgDesc &Arg, unsigned KernelIndex,
239
246
unsigned ArgFunctionIndex, NDRDescT NDRange,
247
+ std::optional<size_t > UserGlobalSize,
240
248
PromotionMap &Promotions) {
241
249
assert (Arg.MType == kernel_param_kind_t ::kind_accessor);
242
250
243
251
Requirement *Req = static_cast <Requirement *>(Arg.MPtr );
244
252
245
253
auto ThisPromotionTarget = getInternalizationInfo (Req);
246
- auto ThisLocalSize = getLocalSize (NDRange, Req, ThisPromotionTarget);
254
+ auto ThisLocalSize =
255
+ getLocalSize (NDRange, UserGlobalSize, Req, ThisPromotionTarget);
247
256
248
257
if (Promotions.count (Req->MSYCLMemObj )) {
249
258
// We previously encountered an accessor for the same buffer.
@@ -278,7 +287,7 @@ static void resolveInternalization(ArgDesc &Arg, unsigned KernelIndex,
278
287
// Recompute the local size for the previous definition with adapted
279
288
// promotion target.
280
289
auto NewPrevLocalSize =
281
- getLocalSize (PreviousDefinition.NDRange ,
290
+ getLocalSize (PreviousDefinition.NDRange , std::nullopt,
282
291
PreviousDefinition.Definition , Promotion::Local);
283
292
284
293
if (!NewPrevLocalSize.has_value ()) {
@@ -316,7 +325,8 @@ static void resolveInternalization(ArgDesc &Arg, unsigned KernelIndex,
316
325
317
326
if (PreviousDefinition.PromotionTarget == Promotion::Local) {
318
327
// Recompute the local size with adapted promotion target.
319
- auto ThisLocalSize = getLocalSize (NDRange, Req, Promotion::Local);
328
+ auto ThisLocalSize =
329
+ getLocalSize (NDRange, std::nullopt, Req, Promotion::Local);
320
330
if (!ThisLocalSize.has_value ()) {
321
331
printPerformanceWarning (" Work-group size for local promotion not "
322
332
" specified, not performing internalization" );
@@ -591,11 +601,12 @@ updatePromotedArgs(const ::jit_compiler::SYCLKernelInfo &FusedKernelInfo,
591
601
// argument is later on passed to the kernel.
592
602
const size_t SizeAccField =
593
603
sizeof (size_t ) * (Req->MDims == 0 ? 1 : Req->MDims );
594
- // Compute the local size and use it for the range parameters.
595
- auto LocalSize = getLocalSize (NDRange, Req,
596
- (PromotedToPrivate) ? Promotion::Private
597
- : Promotion::Local);
598
- range<3 > AccessRange{1 , 1 , LocalSize.value ()};
604
+ // Compute the local size and use it for the range parameters (only
605
+ // relevant for local promotion).
606
+ size_t LocalSize = PromotedToLocal ? *getLocalSize (NDRange, std::nullopt,
607
+ Req, Promotion::Local)
608
+ : 0 ;
609
+ range<3 > AccessRange{1 , 1 , LocalSize};
599
610
auto *RangeArg = storePlainArg (FusedArgStorage, AccessRange);
600
611
// Use all-zero as the offset
601
612
id<3 > AcessOffset{0 , 0 , 0 };
@@ -604,7 +615,7 @@ updatePromotedArgs(const ::jit_compiler::SYCLKernelInfo &FusedKernelInfo,
604
615
// Override the arguments.
605
616
// 1. Override the pointer with a std-layout argument with 'nullptr' as
606
617
// value. handler.cpp does the same for local accessors.
607
- int SizeInBytes = Req->MElemSize * LocalSize. value () ;
618
+ int SizeInBytes = Req->MElemSize * LocalSize;
608
619
FusedArgs[ArgIndex] =
609
620
ArgDesc{kernel_param_kind_t ::kind_std_layout, nullptr , SizeInBytes,
610
621
static_cast <int >(ArgIndex)};
@@ -694,6 +705,26 @@ jit_compiler::fuseKernels(QueueImplPtr Queue,
694
705
return A.MIndex < B.MIndex ;
695
706
});
696
707
708
+ // Determine whether the kernel has been subject to DPCPP's range rounding.
709
+ // If so, the first argument will be the original ("user") range.
710
+ std::optional<size_t > UserGlobalSize;
711
+ if ((KernelName.find (" _ZTSN4sycl3_V16detail18RoundedRangeKernel" ) == 0 ||
712
+ KernelName.find (" _ZTSN4sycl3_V16detail19__pf_kernel_wrapper" ) == 0 ) &&
713
+ !Args.empty ()) {
714
+ auto &A0 = Args[0 ];
715
+ auto Dims = KernelCG->MNDRDesc .Dims ;
716
+ assert (A0.MPtr && A0.MSize == static_cast <int >(Dims * sizeof (size_t )) &&
717
+ A0.MType == kernel_param_kind_t ::kind_std_layout &&
718
+ " Unexpected signature for rounded range kernel" );
719
+
720
+ size_t *UGS = reinterpret_cast <size_t *>(A0.MPtr );
721
+ // Range-rounding only applies to the first dimension.
722
+ assert (UGS[0 ] > KernelCG->MNDRDesc .GlobalSize [1 ]);
723
+ assert (Dims < 2 || UGS[1 ] == KernelCG->MNDRDesc .GlobalSize [1 ]);
724
+ assert (Dims < 3 || UGS[2 ] == KernelCG->MNDRDesc .GlobalSize [2 ]);
725
+ UserGlobalSize = UGS[0 ];
726
+ }
727
+
697
728
::jit_compiler::SYCLArgumentDescriptor ArgDescriptor{Args.size ()};
698
729
size_t ArgIndex = 0 ;
699
730
// The kernel function in SPIR-V will only have the non-eliminated
@@ -719,7 +750,8 @@ jit_compiler::fuseKernels(QueueImplPtr Queue,
719
750
if (!Eliminated) {
720
751
if (Arg.MType == kernel_param_kind_t ::kind_accessor) {
721
752
resolveInternalization (Arg, KernelIndex, ArgFunctionIndex,
722
- KernelCG->MNDRDesc , PromotedAccs);
753
+ KernelCG->MNDRDesc , UserGlobalSize,
754
+ PromotedAccs);
723
755
}
724
756
FusedParams.emplace_back (Arg, KernelIndex, ArgFunctionIndex, true );
725
757
++ArgFunctionIndex;
0 commit comments