Skip to content

Commit 4f30e66

Browse files
[SYCL] Enable range rounding for unnamed lambdas (#10736)
1 parent b50da1e commit 4f30e66

File tree

2 files changed

+29
-9
lines changed

2 files changed

+29
-9
lines changed

sycl/include/sycl/handler.hpp

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1063,7 +1063,8 @@ class __SYCL_EXPORT handler {
10631063

10641064
// Disable the rounding-up optimizations under these conditions:
10651065
// 1. The env var SYCL_DISABLE_PARALLEL_FOR_RANGE_ROUNDING is set.
1066-
// 2. The kernel is provided via an interoperability method.
1066+
// 2. The kernel is provided via an interoperability method (this uses a
1067+
// different code path).
10671068
// 3. The range is already a multiple of the rounding factor.
10681069
//
10691070
// Cases 2 and 3 could be supported with extra effort.
@@ -1076,17 +1077,10 @@ class __SYCL_EXPORT handler {
10761077
// call-graph to make this_item calls kernel-specific but this is
10771078
// not considered worthwhile.
10781079

1079-
// Get the kernel name to check condition 2.
1080-
std::string KName = typeid(NameT *).name();
1081-
using KI = detail::KernelInfo<KernelName>;
1082-
bool DisableRounding =
1083-
this->DisableRangeRounding() ||
1084-
(KI::getName() == nullptr || KI::getName()[0] == '\0');
1085-
10861080
// Perform range rounding if rounding-up is enabled
10871081
// and there are sufficient work-items to need rounding
10881082
// and the user-specified range is not a multiple of a "good" value.
1089-
if (!DisableRounding && (NumWorkItems[0] >= MinRangeX) &&
1083+
if (!this->DisableRangeRounding() && (NumWorkItems[0] >= MinRangeX) &&
10901084
(NumWorkItems[0] % MinFactorX != 0)) {
10911085
// It is sufficient to round up just the first dimension.
10921086
// Multiplying the rounded-up value of the first dimension

sycl/test-e2e/Basic/parallel_for_range_roundup.cpp

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,27 @@ void try_id3(size_t size) {
145145
check("Counter = ", Counter, size * 10 * 10);
146146
}
147147

148+
// Exercises parallel_for with an unnamed lambda kernel: the call below passes
// no kernel-name template argument, only the range and the lambda.
// Launches a 3-D range of {size, 10, 10} work-items; each one atomically
// increments a shared counter, and the final count is handed to check()
// together with the expected total size * 10 * 10.
// NOTE(review): check() and Range3 are not declared in this hunk; presumably
// they are file-scope helpers shared with the other try_* tests -- confirm.
void try_unnamed_lambda(size_t size) {
149+
range<3> Size{size, 10, 10};
150+
int Counter = 0;
151+
// Inner scope bounds the lifetime of the buffers and the queue.
// NOTE(review): presumably this relies on SYCL buffer destruction writing the
// device-side result back into Counter before check() runs -- confirm.
{
152+
buffer<range<3>, 1> BufRange(&Range3, 1);
153+
buffer<int, 1> BufCounter(&Counter, 1);
154+
queue myQueue;
155+
156+
myQueue.submit([&](handler &cgh) {
157+
auto AccRange = BufRange.get_access<access::mode::read_write>(cgh);
158+
// Atomic access mode, since every work-item updates the same element.
auto AccCounter = BufCounter.get_access<access::mode::atomic>(cgh);
159+
// Kernel takes id<3>; no kernel name is supplied (unnamed lambda).
cgh.parallel_for(Size, [=](id<3> ID) {
160+
AccCounter[0].fetch_add(1);
161+
// Every work-item stores its dimension-0 index into the same slot, so the
// final stored value depends on execution order; it is not checked here.
AccRange[0][0] = ID[0];
162+
});
163+
});
164+
myQueue.wait();
165+
}
166+
// Expected total is one increment per work-item of the user-specified range.
check("Counter = ", Counter, size * 10 * 10);
167+
}
168+
148169
int main() {
149170
int x;
150171

@@ -155,6 +176,7 @@ int main() {
155176
try_id1(x);
156177
try_id2(x);
157178
try_id3(x);
179+
try_unnamed_lambda(x);
158180

159181
x = 256;
160182
try_item1(x);
@@ -163,6 +185,7 @@ int main() {
163185
try_id1(x);
164186
try_id2(x);
165187
try_id3(x);
188+
try_unnamed_lambda(x);
166189

167190
return 0;
168191
}
@@ -182,6 +205,8 @@ int main() {
182205
// CHECK-NEXT: Counter = 15000
183206
// CHECK-NEXT: parallel_for range adjusted from 1500 to 1504
184207
// CHECK-NEXT: Counter = 150000
208+
// CHECK-NEXT: parallel_for range adjusted from 1500 to 1504
209+
// CHECK-NEXT: Counter = 150000
185210
// CHECK-NEXT: Size seen by user = 256
186211
// CHECK-NEXT: Counter = 256
187212
// CHECK-NEXT: Size seen by user = 256
@@ -191,3 +216,4 @@ int main() {
191216
// CHECK-NEXT: Counter = 256
192217
// CHECK-NEXT: Counter = 2560
193218
// CHECK-NEXT: Counter = 25600
219+
// CHECK-NEXT: Counter = 25600

0 commit comments

Comments
 (0)