Skip to content

Commit 6eb5208

Browse files
committed
Add two new properties to ur_kernel_group_info_t
These two properties allow the program to specify a maximum work-group size in various ways. They are intended to be targeted from languages such as SYCL (see intel/llvm#14518). This PR implements them for CUDA and Native CPU. It should also be possible to support them for HIP in the same fashion. Other adapters using SPIR-V and/or Level Zero would require further changes to both of those specifications.
1 parent 45ad7c5 commit 6eb5208

File tree

23 files changed

+243
-47
lines changed

23 files changed

+243
-47
lines changed

include/ur_api.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4862,6 +4862,10 @@ typedef enum ur_kernel_group_info_t {
48624862
UR_KERNEL_GROUP_INFO_PREFERRED_WORK_GROUP_SIZE_MULTIPLE = 4, ///< [size_t] Return preferred multiple of Work Group size for launch
48634863
UR_KERNEL_GROUP_INFO_PRIVATE_MEM_SIZE = 5, ///< [size_t] Return minimum amount of private memory in bytes used by each
48644864
///< work item in the Kernel
4865+
UR_KERNEL_GROUP_INFO_COMPILE_MAX_WORK_GROUP_SIZE = 6, ///< [size_t[3]] Return the maximum Work Group size guaranteed by the
4866+
///< source code, or (0, 0, 0) if unspecified
4867+
UR_KERNEL_GROUP_INFO_COMPILE_MAX_LINEAR_WORK_GROUP_SIZE = 7, ///< [size_t] Return the maximum linearized Work Group size (X * Y * Z)
4868+
///< guaranteed by the source code, or 0 if unspecified
48654869
/// @cond
48664870
UR_KERNEL_GROUP_INFO_FORCE_UINT32 = 0x7fffffff
48674871
/// @endcond
@@ -4965,7 +4969,7 @@ urKernelGetInfo(
49654969
/// + `NULL == hKernel`
49664970
/// + `NULL == hDevice`
49674971
/// - ::UR_RESULT_ERROR_INVALID_ENUMERATION
4968-
/// + `::UR_KERNEL_GROUP_INFO_PRIVATE_MEM_SIZE < propName`
4972+
/// + `::UR_KERNEL_GROUP_INFO_COMPILE_MAX_LINEAR_WORK_GROUP_SIZE < propName`
49694973
UR_APIEXPORT ur_result_t UR_APICALL
49704974
urKernelGetGroupInfo(
49714975
ur_kernel_handle_t hKernel, ///< [in] handle of the Kernel object

include/ur_print.hpp

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7994,6 +7994,12 @@ inline std::ostream &operator<<(std::ostream &os, enum ur_kernel_group_info_t va
79947994
case UR_KERNEL_GROUP_INFO_PRIVATE_MEM_SIZE:
79957995
os << "UR_KERNEL_GROUP_INFO_PRIVATE_MEM_SIZE";
79967996
break;
7997+
case UR_KERNEL_GROUP_INFO_COMPILE_MAX_WORK_GROUP_SIZE:
7998+
os << "UR_KERNEL_GROUP_INFO_COMPILE_MAX_WORK_GROUP_SIZE";
7999+
break;
8000+
case UR_KERNEL_GROUP_INFO_COMPILE_MAX_LINEAR_WORK_GROUP_SIZE:
8001+
os << "UR_KERNEL_GROUP_INFO_COMPILE_MAX_LINEAR_WORK_GROUP_SIZE";
8002+
break;
79978003
default:
79988004
os << "unknown enumerator";
79998005
break;
@@ -8086,6 +8092,32 @@ inline ur_result_t printTagged(std::ostream &os, const void *ptr, ur_kernel_grou
80868092

80878093
os << ")";
80888094
} break;
8095+
case UR_KERNEL_GROUP_INFO_COMPILE_MAX_WORK_GROUP_SIZE: {
8096+
8097+
const size_t *tptr = (const size_t *)ptr;
8098+
os << "{";
8099+
size_t nelems = size / sizeof(size_t);
8100+
for (size_t i = 0; i < nelems; ++i) {
8101+
if (i != 0) {
8102+
os << ", ";
8103+
}
8104+
8105+
os << tptr[i];
8106+
}
8107+
os << "}";
8108+
} break;
8109+
case UR_KERNEL_GROUP_INFO_COMPILE_MAX_LINEAR_WORK_GROUP_SIZE: {
8110+
const size_t *tptr = (const size_t *)ptr;
8111+
if (sizeof(size_t) > size) {
8112+
os << "invalid size (is: " << size << ", expected: >=" << sizeof(size_t) << ")";
8113+
return UR_RESULT_ERROR_INVALID_SIZE;
8114+
}
8115+
os << (const void *)(tptr) << " (";
8116+
8117+
os << *tptr;
8118+
8119+
os << ")";
8120+
} break;
80898121
default:
80908122
os << "unknown enumerator";
80918123
return UR_RESULT_ERROR_INVALID_ENUMERATION;

scripts/core/kernel.yml

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,14 @@ etors:
144144
desc: "[size_t] Return preferred multiple of Work Group size for launch"
145145
- name: PRIVATE_MEM_SIZE
146146
desc: "[size_t] Return minimum amount of private memory in bytes used by each work item in the Kernel"
147+
- name: COMPILE_MAX_WORK_GROUP_SIZE
148+
desc: |
149+
[size_t[3]] Return the maximum Work Group size guaranteed by the
150+
source code, or (0, 0, 0) if unspecified
151+
- name: COMPILE_MAX_LINEAR_WORK_GROUP_SIZE
152+
desc: |
153+
[size_t] Return the maximum linearized Work Group size (X * Y * Z)
154+
guaranteed by the source code, or 0 if unspecified
147155
--- #--------------------------------------------------------------------------
148156
type: enum
149157
desc: "Get Kernel SubGroup information"

source/adapters/cuda/enqueue.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,7 @@ setKernelParams([[maybe_unused]] const ur_context_handle_t Context,
203203
// Set the active context here as guessLocalWorkSize needs an active context
204204
ScopedContext Active(Device);
205205
{
206+
size_t *MaxThreadsPerBlock = Kernel->MaxThreadsPerBlock;
206207
size_t *ReqdThreadsPerBlock = Kernel->ReqdThreadsPerBlock;
207208
MaxWorkGroupSize = Device->getMaxWorkGroupSize();
208209

@@ -212,6 +213,10 @@ setKernelParams([[maybe_unused]] const ur_context_handle_t Context,
212213
LocalWorkSize[Dim] != ReqdThreadsPerBlock[Dim])
213214
return UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE;
214215

216+
if (MaxThreadsPerBlock[Dim] != 0 &&
217+
LocalWorkSize[Dim] > MaxThreadsPerBlock[Dim])
218+
return UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE;
219+
215220
if (LocalWorkSize[Dim] > Device->getMaxWorkItemSizes(Dim))
216221
return UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE;
217222
// Checks that local work sizes are a divisor of the global work sizes
@@ -235,6 +240,12 @@ setKernelParams([[maybe_unused]] const ur_context_handle_t Context,
235240
KernelLocalWorkGroupSize *= LocalWorkSize[Dim];
236241
}
237242

243+
if (size_t MaxLinearThreadsPerBlock = Kernel->MaxLinearThreadsPerBlock;
244+
MaxLinearThreadsPerBlock &&
245+
MaxLinearThreadsPerBlock < KernelLocalWorkGroupSize) {
246+
return UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE;
247+
}
248+
238249
if (hasExceededMaxRegistersPerBlock(Device, Kernel,
239250
KernelLocalWorkGroupSize)) {
240251
return UR_RESULT_ERROR_OUT_OF_RESOURCES;

source/adapters/cuda/kernel.cpp

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,30 @@ urKernelGetGroupInfo(ur_kernel_handle_t hKernel, ur_device_handle_t hDevice,
125125
&Bytes, CU_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES, hKernel->get()));
126126
return ReturnValue(uint64_t(Bytes));
127127
}
128+
case UR_KERNEL_GROUP_INFO_COMPILE_MAX_WORK_GROUP_SIZE: {
129+
size_t MaxGroupSize[3] = {0, 0, 0};
130+
const auto &MaxWGSizeMDMap =
131+
hKernel->getProgram()->KernelMaxWorkGroupSizeMD;
132+
const auto MaxWGSizeMD = MaxWGSizeMDMap.find(hKernel->getName());
133+
if (MaxWGSizeMD != MaxWGSizeMDMap.end()) {
134+
const auto MaxWGSize = MaxWGSizeMD->second;
135+
MaxGroupSize[0] = std::get<0>(MaxWGSize);
136+
MaxGroupSize[1] = std::get<1>(MaxWGSize);
137+
MaxGroupSize[2] = std::get<2>(MaxWGSize);
138+
}
139+
return ReturnValue(MaxGroupSize, 3);
140+
}
141+
case UR_KERNEL_GROUP_INFO_COMPILE_MAX_LINEAR_WORK_GROUP_SIZE: {
142+
size_t MaxLinearGroupSize = 0;
143+
const auto &MaxLinearWGSizeMDMap =
144+
hKernel->getProgram()->KernelMaxLinearWorkGroupSizeMD;
145+
const auto MaxLinearWGSizeMD =
146+
MaxLinearWGSizeMDMap.find(hKernel->getName());
147+
if (MaxLinearWGSizeMD != MaxLinearWGSizeMDMap.end()) {
148+
MaxLinearGroupSize = MaxLinearWGSizeMD->second;
149+
}
150+
return ReturnValue(MaxLinearGroupSize);
151+
}
128152
default:
129153
break;
130154
}

source/adapters/cuda/kernel.hpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,8 @@ struct ur_kernel_handle_t_ {
4646

4747
static constexpr uint32_t ReqdThreadsPerBlockDimensions = 3u;
4848
size_t ReqdThreadsPerBlock[ReqdThreadsPerBlockDimensions];
49+
size_t MaxThreadsPerBlock[ReqdThreadsPerBlockDimensions];
50+
size_t MaxLinearThreadsPerBlock{0};
4951
int RegsPerThread{0};
5052

5153
/// Structure that holds the arguments to the kernel.
@@ -169,6 +171,18 @@ struct ur_kernel_handle_t_ {
169171
sizeof(ReqdThreadsPerBlock), ReqdThreadsPerBlock, nullptr);
170172
(void)RetError;
171173
assert(RetError == UR_RESULT_SUCCESS);
174+
/// Note: this code assumes that there is only one device per context
175+
RetError = urKernelGetGroupInfo(
176+
this, Program->getDevice(),
177+
UR_KERNEL_GROUP_INFO_COMPILE_MAX_WORK_GROUP_SIZE,
178+
sizeof(MaxThreadsPerBlock), MaxThreadsPerBlock, nullptr);
179+
assert(RetError == UR_RESULT_SUCCESS);
180+
/// Note: this code assumes that there is only one device per context
181+
RetError = urKernelGetGroupInfo(
182+
this, Program->getDevice(),
183+
UR_KERNEL_GROUP_INFO_COMPILE_MAX_LINEAR_WORK_GROUP_SIZE,
184+
sizeof(MaxLinearThreadsPerBlock), &MaxLinearThreadsPerBlock, nullptr);
185+
assert(RetError == UR_RESULT_SUCCESS);
172186
UR_CHECK_ERROR(
173187
cuFuncGetAttribute(&RegsPerThread, CU_FUNC_ATTRIBUTE_NUM_REGS, Func));
174188
}

source/adapters/cuda/program.cpp

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -54,9 +54,10 @@ ur_program_handle_t_::setMetadata(const ur_program_metadata_t *Metadata,
5454

5555
auto [Prefix, Tag] = splitMetadataName(MetadataElementName);
5656

57-
if (Tag == __SYCL_UR_PROGRAM_METADATA_TAG_REQD_WORK_GROUP_SIZE) {
58-
// If metadata is reqd_work_group_size, record it for the corresponding
59-
// kernel name.
57+
if (Tag == __SYCL_UR_PROGRAM_METADATA_TAG_REQD_WORK_GROUP_SIZE ||
58+
Tag == __SYCL_UR_PROGRAM_METADATA_TAG_MAX_WORK_GROUP_SIZE) {
59+
// If metadata is reqd_work_group_size/max_work_group_size, record it for
60+
// the corresponding kernel name.
6061
size_t MDElemsSize = MetadataElement.size - sizeof(std::uint64_t);
6162

6263
// Expect between 1 and 3 32-bit integer values.
@@ -69,18 +70,23 @@ ur_program_handle_t_::setMetadata(const ur_program_metadata_t *Metadata,
6970
reinterpret_cast<const char *>(MetadataElement.value.pData) +
7071
sizeof(std::uint64_t);
7172
// Read values and pad with 1's for values not present.
72-
std::uint32_t ReqdWorkGroupElements[] = {1, 1, 1};
73-
std::memcpy(ReqdWorkGroupElements, ValuePtr, MDElemsSize);
74-
KernelReqdWorkGroupSizeMD[Prefix] =
75-
std::make_tuple(ReqdWorkGroupElements[0], ReqdWorkGroupElements[1],
76-
ReqdWorkGroupElements[2]);
73+
std::array<uint32_t, 3> WorkGroupElements = {1, 1, 1};
74+
std::memcpy(WorkGroupElements.data(), ValuePtr, MDElemsSize);
75+
(Tag == __SYCL_UR_PROGRAM_METADATA_TAG_REQD_WORK_GROUP_SIZE
76+
? KernelReqdWorkGroupSizeMD
77+
: KernelMaxWorkGroupSizeMD)[Prefix] =
78+
std::make_tuple(WorkGroupElements[0], WorkGroupElements[1],
79+
WorkGroupElements[2]);
7780
} else if (Tag == __SYCL_UR_PROGRAM_METADATA_GLOBAL_ID_MAPPING) {
7881
const char *MetadataValPtr =
7982
reinterpret_cast<const char *>(MetadataElement.value.pData) +
8083
sizeof(std::uint64_t);
8184
const char *MetadataValPtrEnd =
8285
MetadataValPtr + MetadataElement.size - sizeof(std::uint64_t);
8386
GlobalIDMD[Prefix] = std::string{MetadataValPtr, MetadataValPtrEnd};
87+
} else if (Tag ==
88+
__SYCL_UR_PROGRAM_METADATA_TAG_MAX_LINEAR_WORK_GROUP_SIZE) {
89+
KernelMaxLinearWorkGroupSizeMD[Prefix] = MetadataElement.value.data64;
8490
}
8591
}
8692
return UR_RESULT_SUCCESS;

source/adapters/cuda/program.hpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,9 @@ struct ur_program_handle_t_ {
3636
std::unordered_map<std::string, std::tuple<uint32_t, uint32_t, uint32_t>>
3737
KernelReqdWorkGroupSizeMD;
3838
std::unordered_map<std::string, std::string> GlobalIDMD;
39+
std::unordered_map<std::string, std::tuple<uint32_t, uint32_t, uint32_t>>
40+
KernelMaxWorkGroupSizeMD;
41+
std::unordered_map<std::string, uint64_t> KernelMaxLinearWorkGroupSizeMD;
3942

4043
constexpr static size_t MaxLogSize = 8192u;
4144

@@ -45,7 +48,8 @@ struct ur_program_handle_t_ {
4548

4649
ur_program_handle_t_(ur_context_handle_t Context, ur_device_handle_t Device)
4750
: Module{nullptr}, Binary{}, BinarySizeInBytes{0}, RefCount{1},
48-
Context{Context}, Device{Device}, KernelReqdWorkGroupSizeMD{} {
51+
Context{Context}, Device{Device}, KernelReqdWorkGroupSizeMD{},
52+
KernelMaxWorkGroupSizeMD{}, KernelMaxLinearWorkGroupSizeMD{} {
4953
urContextRetain(Context);
5054
urDeviceRetain(Device);
5155
}

source/adapters/hip/kernel.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,10 @@ urKernelGetGroupInfo(ur_kernel_handle_t hKernel, ur_device_handle_t hDevice,
127127
&Bytes, HIP_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES, hKernel->get()));
128128
return ReturnValue(uint64_t(Bytes));
129129
}
130+
case UR_KERNEL_GROUP_INFO_COMPILE_MAX_WORK_GROUP_SIZE:
131+
case UR_KERNEL_GROUP_INFO_COMPILE_MAX_LINEAR_WORK_GROUP_SIZE:
132+
// FIXME: could be added
133+
return UR_RESULT_ERROR_UNSUPPORTED_ENUMERATION;
130134
default:
131135
break;
132136
}

source/adapters/level_zero/kernel.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -856,6 +856,10 @@ ur_result_t urKernelGetGroupInfo(
856856
case UR_KERNEL_GROUP_INFO_PRIVATE_MEM_SIZE: {
857857
return ReturnValue(uint32_t{Kernel->ZeKernelProperties->privateMemSize});
858858
}
859+
case UR_KERNEL_GROUP_INFO_COMPILE_MAX_WORK_GROUP_SIZE:
860+
case UR_KERNEL_GROUP_INFO_COMPILE_MAX_LINEAR_WORK_GROUP_SIZE:
861+
// No corresponding enumeration in Level Zero
862+
return UR_RESULT_ERROR_UNSUPPORTED_ENUMERATION;
859863
default: {
860864
logger::error(
861865
"Unknown ParamName in urKernelGetGroupInfo: ParamName={}(0x{})",

source/adapters/level_zero/v2/kernel.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -417,6 +417,10 @@ ur_result_t urKernelGetGroupInfo(
417417
auto props = hKernel->getProperties(hDevice);
418418
return returnValue(uint32_t{props.privateMemSize});
419419
}
420+
case UR_KERNEL_GROUP_INFO_COMPILE_MAX_WORK_GROUP_SIZE:
421+
case UR_KERNEL_GROUP_INFO_COMPILE_MAX_LINEAR_WORK_GROUP_SIZE:
422+
// No corresponding enumeration in Level Zero
423+
return UR_RESULT_ERROR_UNSUPPORTED_ENUMERATION;
420424
default: {
421425
logger::error(
422426
"Unknown ParamName in urKernelGetGroupInfo: ParamName={}(0x{})",

source/adapters/native_cpu/enqueue.cpp

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -81,11 +81,22 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
8181
DIE_NO_IMPLEMENTATION;
8282
}
8383

84-
// Check reqd_work_group_size
85-
if (hKernel->hasReqdWGSize() && pLocalWorkSize != nullptr) {
86-
const auto &Reqd = hKernel->getReqdWGSize();
84+
// Check reqd_work_group_size and other kernel constraints
85+
if (pLocalWorkSize != nullptr) {
86+
uint64_t TotalNumWIs = 1;
8787
for (uint32_t Dim = 0; Dim < workDim; Dim++) {
88-
if (pLocalWorkSize[Dim] != Reqd[Dim]) {
88+
TotalNumWIs *= pLocalWorkSize[Dim];
89+
if (auto Reqd = hKernel->getReqdWGSize();
90+
Reqd && pLocalWorkSize[Dim] != Reqd.value()[Dim]) {
91+
return UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE;
92+
}
93+
if (auto MaxWG = hKernel->getMaxWGSize();
94+
MaxWG && pLocalWorkSize[Dim] > MaxWG.value()[Dim]) {
95+
return UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE;
96+
}
97+
}
98+
if (auto MaxLinearWG = hKernel->getMaxLinearWGSize()) {
99+
if (TotalNumWIs > MaxLinearWG) {
89100
return UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE;
90101
}
91102
}

source/adapters/native_cpu/kernel.cpp

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -31,14 +31,25 @@ urKernelCreate(ur_program_handle_t hProgram, const char *pKernelName,
3131
ur_kernel_handle_t_ *kernel;
3232

3333
// Set reqd_work_group_size for kernel if needed
34+
std::optional<native_cpu::WGSize_t> ReqdWG;
3435
const auto &ReqdMap = hProgram->KernelReqdWorkGroupSizeMD;
35-
auto ReqdIt = ReqdMap.find(pKernelName);
36-
if (ReqdIt != ReqdMap.end()) {
37-
kernel = new ur_kernel_handle_t_(hProgram, pKernelName, *f, ReqdIt->second);
38-
} else {
39-
kernel = new ur_kernel_handle_t_(hProgram, pKernelName, *f);
36+
if (auto ReqdIt = ReqdMap.find(pKernelName); ReqdIt != ReqdMap.end()) {
37+
ReqdWG = ReqdIt->second;
4038
}
4139

40+
std::optional<native_cpu::WGSize_t> MaxWG;
41+
const auto &MaxMap = hProgram->KernelMaxWorkGroupSizeMD;
42+
if (auto MaxIt = MaxMap.find(pKernelName); MaxIt != MaxMap.end()) {
43+
MaxWG = MaxIt->second;
44+
}
45+
std::optional<uint64_t> MaxLinearWG;
46+
const auto &MaxLinMap = hProgram->KernelMaxLinearWorkGroupSizeMD;
47+
if (auto MaxLIt = MaxLinMap.find(pKernelName); MaxLIt != MaxLinMap.end()) {
48+
MaxLinearWG = MaxLIt->second;
49+
}
50+
kernel = new ur_kernel_handle_t_(hProgram, pKernelName, *f, ReqdWG, MaxWG,
51+
MaxLinearWG);
52+
4253
*phKernel = kernel;
4354

4455
return UR_RESULT_SUCCESS;
@@ -148,6 +159,10 @@ urKernelGetGroupInfo(ur_kernel_handle_t hKernel, ur_device_handle_t hDevice,
148159
int bytes = 0;
149160
return returnValue(static_cast<uint64_t>(bytes));
150161
}
162+
case UR_KERNEL_GROUP_INFO_COMPILE_MAX_WORK_GROUP_SIZE:
163+
case UR_KERNEL_GROUP_INFO_COMPILE_MAX_LINEAR_WORK_GROUP_SIZE:
164+
// FIXME: could be added
165+
return UR_RESULT_ERROR_UNSUPPORTED_ENUMERATION;
151166

152167
default:
153168
break;

source/adapters/native_cpu/kernel.hpp

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -41,15 +41,14 @@ struct ur_kernel_handle_t_ : RefCounted {
4141

4242
ur_kernel_handle_t_(ur_program_handle_t hProgram, const char *name,
4343
nativecpu_task_t subhandler)
44-
: hProgram(hProgram), _name{name}, _subhandler{std::move(subhandler)},
45-
HasReqdWGSize(false) {}
44+
: hProgram(hProgram), _name{name}, _subhandler{std::move(subhandler)} {}
4645

4746
ur_kernel_handle_t_(const ur_kernel_handle_t_ &other)
4847
: hProgram(other.hProgram), _name(other._name),
4948
_subhandler(other._subhandler), _args(other._args),
5049
_localArgInfo(other._localArgInfo), _localMemPool(other._localMemPool),
5150
_localMemPoolSize(other._localMemPoolSize),
52-
HasReqdWGSize(other.HasReqdWGSize), ReqdWGSize(other.ReqdWGSize) {
51+
ReqdWGSize(other.ReqdWGSize) {
5352
incrementReferenceCount();
5453
}
5554

@@ -60,19 +59,26 @@ struct ur_kernel_handle_t_ : RefCounted {
6059
}
6160
ur_kernel_handle_t_(ur_program_handle_t hProgram, const char *name,
6261
nativecpu_task_t subhandler,
63-
const native_cpu::ReqdWGSize_t &ReqdWGSize)
62+
std::optional<native_cpu::WGSize_t> ReqdWGSize,
63+
std::optional<native_cpu::WGSize_t> MaxWGSize,
64+
std::optional<uint64_t> MaxLinearWGSize)
6465
: hProgram(hProgram), _name{name}, _subhandler{std::move(subhandler)},
65-
HasReqdWGSize(true), ReqdWGSize(ReqdWGSize) {}
66+
ReqdWGSize(ReqdWGSize), MaxWGSize(MaxWGSize),
67+
MaxLinearWGSize(MaxLinearWGSize) {}
6668

6769
ur_program_handle_t hProgram;
6870
std::string _name;
6971
nativecpu_task_t _subhandler;
7072
std::vector<native_cpu::NativeCPUArgDesc> _args;
7173
std::vector<local_arg_info_t> _localArgInfo;
7274

73-
bool hasReqdWGSize() const { return HasReqdWGSize; }
75+
std::optional<native_cpu::WGSize_t> getReqdWGSize() const {
76+
return ReqdWGSize;
77+
}
78+
79+
std::optional<native_cpu::WGSize_t> getMaxWGSize() const { return MaxWGSize; }
7480

75-
const native_cpu::ReqdWGSize_t &getReqdWGSize() const { return ReqdWGSize; }
81+
std::optional<uint64_t> getMaxLinearWGSize() const { return MaxLinearWGSize; }
7682

7783
void updateMemPool(size_t numParallelThreads) {
7884
// compute requested size.
@@ -103,6 +109,7 @@ struct ur_kernel_handle_t_ : RefCounted {
103109
private:
104110
char *_localMemPool = nullptr;
105111
size_t _localMemPoolSize = 0;
106-
bool HasReqdWGSize;
107-
native_cpu::ReqdWGSize_t ReqdWGSize;
112+
std::optional<native_cpu::WGSize_t> ReqdWGSize = std::nullopt;
113+
std::optional<native_cpu::WGSize_t> MaxWGSize = std::nullopt;
114+
std::optional<uint64_t> MaxLinearWGSize = std::nullopt;
108115
};

0 commit comments

Comments
 (0)