Skip to content

Commit 35f656b

Browse files
authored
[SYCL] Optimize setNDRangeDescriptor functions (#18132)
1 parent f805102 commit 35f656b

File tree

2 files changed

+25
-31
lines changed

2 files changed

+25
-31
lines changed

sycl/include/sycl/handler.hpp

Lines changed: 9 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -3687,11 +3687,13 @@ class __SYCL_EXPORT handler {
36873687
bool HasAssociatedAccessor(detail::AccessorImplHost *Req,
36883688
access::target AccessTarget) const;
36893689

3690-
template <int Dims> static sycl::range<3> padRange(sycl::range<Dims> Range) {
3690+
template <int Dims>
3691+
static sycl::range<3> padRange(sycl::range<Dims> Range,
3692+
[[maybe_unused]] size_t DefaultValue = 0) {
36913693
if constexpr (Dims == 3) {
36923694
return Range;
36933695
} else {
3694-
sycl::range<3> Res{0, 0, 0};
3696+
sycl::range<3> Res{DefaultValue, DefaultValue, DefaultValue};
36953697
for (int I = 0; I < Dims; ++I)
36963698
Res[I] = Range[I];
36973699
return Res;
@@ -3712,7 +3714,8 @@ class __SYCL_EXPORT handler {
37123714
template <int Dims>
37133715
void setNDRangeDescriptor(sycl::range<Dims> N,
37143716
bool SetNumWorkGroups = false) {
3715-
return setNDRangeDescriptorPadded(padRange(N), SetNumWorkGroups, Dims);
3717+
return setNDRangeDescriptorPadded(padRange(N, SetNumWorkGroups ? 0 : 1),
3718+
SetNumWorkGroups, Dims);
37163719
}
37173720
template <int Dims>
37183721
void setNDRangeDescriptor(sycl::range<Dims> NumWorkItems,
@@ -3722,9 +3725,10 @@ class __SYCL_EXPORT handler {
37223725
}
37233726
template <int Dims>
37243727
void setNDRangeDescriptor(sycl::nd_range<Dims> ExecutionRange) {
3728+
sycl::range<Dims> LocalRange = ExecutionRange.get_local_range();
37253729
return setNDRangeDescriptorPadded(
3726-
padRange(ExecutionRange.get_global_range()),
3727-
padRange(ExecutionRange.get_local_range()),
3730+
padRange(ExecutionRange.get_global_range(), 1),
3731+
padRange(LocalRange, LocalRange[0] ? 1 : 0),
37283732
padId(ExecutionRange.get_offset()), Dims);
37293733
}
37303734

sycl/source/detail/cg.hpp

Lines changed: 16 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -63,22 +63,13 @@ class ArgDesc {
6363
// The structure represents NDRange - global, local sizes, global offset and
6464
// number of dimensions.
6565
class NDRDescT {
66-
// The method initializes all sizes for dimensions greater than the passed one
67-
// to the default values, so they will not affect execution.
68-
void setNDRangeLeftover() {
69-
for (int I = Dims; I < 3; ++I) {
70-
GlobalSize[I] = 1;
71-
LocalSize[I] = LocalSize[0] ? 1 : 0;
72-
GlobalOffset[I] = 0;
73-
NumWorkGroups[I] = 0;
74-
}
75-
}
76-
77-
template <int Dims> static sycl::range<3> padRange(sycl::range<Dims> Range) {
66+
template <int Dims>
67+
static sycl::range<3> padRange(sycl::range<Dims> Range,
68+
[[maybe_unused]] size_t DefaultValue = 0) {
7869
if constexpr (Dims == 3) {
7970
return Range;
8071
} else {
81-
sycl::range<3> Res{0, 0, 0};
72+
sycl::range<3> Res{DefaultValue, DefaultValue, DefaultValue};
8273
for (int I = 0; I < Dims; ++I)
8374
Res[I] = Range[I];
8475
return Res;
@@ -102,37 +93,36 @@ class NDRDescT {
10293
NDRDescT(NDRDescT &&Desc) = default;
10394

10495
NDRDescT(sycl::range<3> N, bool SetNumWorkGroups, int DimsArg)
105-
: GlobalSize{SetNumWorkGroups ? sycl::range<3>{0, 0, 0} : N},
106-
NumWorkGroups{SetNumWorkGroups ? N : sycl::range<3>{0, 0, 0}},
107-
Dims{size_t(DimsArg)} {
108-
setNDRangeLeftover();
96+
: Dims{size_t(DimsArg)} {
97+
if (SetNumWorkGroups) {
98+
NumWorkGroups = N;
99+
} else {
100+
GlobalSize = N;
101+
}
109102
}
110103

111104
NDRDescT(sycl::range<3> NumWorkItems, sycl::range<3> LocalSize,
112105
sycl::id<3> Offset, int DimsArg)
113106
: GlobalSize{NumWorkItems}, LocalSize{LocalSize}, GlobalOffset{Offset},
114-
Dims{size_t(DimsArg)} {
115-
setNDRangeLeftover();
116-
}
107+
Dims{size_t(DimsArg)} {}
117108

118109
NDRDescT(sycl::range<3> NumWorkItems, sycl::id<3> Offset, int DimsArg)
119110
: GlobalSize{NumWorkItems}, GlobalOffset{Offset}, Dims{size_t(DimsArg)} {}
120111

121112
template <int Dims_>
122113
NDRDescT(sycl::nd_range<Dims_> ExecutionRange, int DimsArg)
123-
: NDRDescT(padRange(ExecutionRange.get_global_range()),
124-
padRange(ExecutionRange.get_local_range()),
125-
padId(ExecutionRange.get_offset()), size_t(DimsArg)) {
126-
setNDRangeLeftover();
127-
}
114+
: NDRDescT(padRange(ExecutionRange.get_global_range(), 1),
115+
padRange(ExecutionRange.get_local_range(),
116+
ExecutionRange.get_local_range()[0] ? 1 : 0),
117+
padId(ExecutionRange.get_offset()), size_t(DimsArg)) {}
128118

129119
template <int Dims_>
130120
NDRDescT(sycl::nd_range<Dims_> ExecutionRange)
131121
: NDRDescT(ExecutionRange, Dims_) {}
132122

133123
template <int Dims_>
134124
NDRDescT(sycl::range<Dims_> Range)
135-
: NDRDescT(padRange(Range), /*SetNumWorkGroups=*/false, Dims_) {}
125+
: NDRDescT(padRange(Range, 1), /*SetNumWorkGroups=*/false, Dims_) {}
136126

137127
void setClusterDimensions(sycl::range<3> N, int Dims) {
138128
if (this->Dims != size_t(Dims)) {

0 commit comments

Comments
 (0)