Commit 226e80e

[UR][CUDA][HIP] Refactor setKernelParams
- Remove unused Context parameters
- Avoid unnecessary copy in `guessLocalWorkSize`
- Simplify the control flow in `setKernelParams`
- Move cached properties fetching code to constructors
- Query HIP for occupancy in `guessLocalWorkSize`
1 parent: 420ac96

11 files changed: +151, -210 lines

unified-runtime/source/adapters/cuda/command_buffer.cpp

Lines changed: 5 additions & 6 deletions
@@ -502,9 +502,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendKernelLaunchExp(
   uint32_t LocalSize = hKernel->getLocalSize();
   CUfunction CuFunc = hKernel->get();
   UR_CHECK_ERROR(setKernelParams(
-      hCommandBuffer->Context, hCommandBuffer->Device, workDim,
-      pGlobalWorkOffset, pGlobalWorkSize, pLocalWorkSize, hKernel, CuFunc,
-      ThreadsPerBlock, BlocksPerGrid));
+      hCommandBuffer->Device, workDim, pGlobalWorkOffset, pGlobalWorkSize,
+      pLocalWorkSize, hKernel, CuFunc, ThreadsPerBlock, BlocksPerGrid));

   // Set node param structure with the kernel related data
   auto &ArgPointers = hKernel->getArgPointers();
@@ -1373,9 +1372,9 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferUpdateKernelLaunchExp(
   size_t BlocksPerGrid[3] = {1u, 1u, 1u};
   CUfunction CuFunc = KernelData.Kernel->get();
   auto Result = setKernelParams(
-      hCommandBuffer->Context, hCommandBuffer->Device, KernelData.WorkDim,
-      KernelData.GlobalWorkOffset, KernelData.GlobalWorkSize, LocalWorkSize,
-      KernelData.Kernel, CuFunc, ThreadsPerBlock, BlocksPerGrid);
+      hCommandBuffer->Device, KernelData.WorkDim, KernelData.GlobalWorkOffset,
+      KernelData.GlobalWorkSize, LocalWorkSize, KernelData.Kernel, CuFunc,
+      ThreadsPerBlock, BlocksPerGrid);
   if (Result != UR_RESULT_SUCCESS) {
     return Result;
   }

unified-runtime/source/adapters/cuda/device.hpp

Lines changed: 4 additions & 0 deletions
@@ -150,6 +150,10 @@ struct ur_device_handle_t_ : ur::cuda::handle_base {
     return MaxWorkItemSizes[index];
   }

+  const size_t *getMaxWorkItemSizes() const noexcept {
+    return MaxWorkItemSizes;
+  }
+
   size_t getMaxWorkGroupSize() const noexcept { return MaxWorkGroupSize; };

   size_t getMaxRegsPerBlock() const noexcept { return MaxRegsPerBlock; };
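The added overload returns the cached three-element array directly, so callers that need every dimension can pass the pointer along instead of copying the limits one index at a time, which is what `guessLocalWorkSize` in enqueue.cpp now does. A minimal standalone sketch of the pattern; `MockDevice` and `printLimits` are illustrative stand-ins, not adapter code:

```cpp
#include <cstddef>
#include <cstdio>

// Stand-in for ur_device_handle_t_ showing the two accessor styles.
struct MockDevice {
  size_t MaxWorkItemSizes[3] = {1024, 1024, 64};

  // Pre-existing per-dimension accessor.
  size_t getMaxWorkItemSizes(int index) const noexcept {
    return MaxWorkItemSizes[index];
  }
  // New whole-array accessor: callers no longer copy the three
  // values into a local buffer before passing them on.
  const size_t *getMaxWorkItemSizes() const noexcept {
    return MaxWorkItemSizes;
  }
};

// Helper taking the array directly, the way
// roundToHighestFactorOfGlobalSizeIn3d now receives it.
void printLimits(const size_t *MaxBlockDim) {
  std::printf("%zu x %zu x %zu\n", MaxBlockDim[0], MaxBlockDim[1],
              MaxBlockDim[2]);
}

int main() {
  MockDevice Device;
  printLimits(Device.getMaxWorkItemSizes()); // no intermediate copy
  return 0;
}
```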

unified-runtime/source/adapters/cuda/enqueue.cpp

Lines changed: 46 additions & 70 deletions
@@ -119,18 +119,14 @@ void guessLocalWorkSize(ur_device_handle_t Device, size_t *ThreadsPerBlock,
     GlobalSizeNormalized[i] = GlobalWorkSize[i];
   }

-  size_t MaxBlockDim[3];
-  MaxBlockDim[0] = Device->getMaxWorkItemSizes(0);
-  MaxBlockDim[1] = Device->getMaxWorkItemSizes(1);
-  MaxBlockDim[2] = Device->getMaxWorkItemSizes(2);
-
   int MinGrid, MaxBlockSize;
   UR_CHECK_ERROR(cuOccupancyMaxPotentialBlockSize(
       &MinGrid, &MaxBlockSize, Kernel->get(), NULL, Kernel->getLocalSize(),
-      MaxBlockDim[0]));
+      Device->getMaxWorkItemSizes(0)));

   roundToHighestFactorOfGlobalSizeIn3d(ThreadsPerBlock, GlobalSizeNormalized,
-                                       MaxBlockDim, MaxBlockSize);
+                                       Device->getMaxWorkItemSizes(),
+                                       MaxBlockSize);
 }

 // Helper to verify out-of-registers case (exceeded block max registers).
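For context, `cuOccupancyMaxPotentialBlockSize` is the CUDA driver's occupancy heuristic: given a kernel and its dynamic shared-memory usage, it suggests the block size that maximizes occupancy, capped by a caller-supplied limit. A hedged sketch of the call shape used above; `suggestBlockSize` is a hypothetical helper and assumes a `CUfunction` loaded elsewhere:

```cpp
#include <cuda.h>
#include <cstddef>

// Ask the driver for an occupancy-maximizing block size, capped at the
// device's limit for dimension 0, which is the same query that
// guessLocalWorkSize performs. DynamicSMem stands in for
// Kernel->getLocalSize().
int suggestBlockSize(CUfunction Func, size_t DynamicSMem, int MaxBlockDimX) {
  int MinGridSize = 0; // smallest grid size that can reach full occupancy
  int BlockSize = 0;   // suggested threads per block
  CUresult Err = cuOccupancyMaxPotentialBlockSize(
      &MinGridSize, &BlockSize, Func,
      /*blockSizeToDynamicSMemSize=*/nullptr, DynamicSMem,
      /*blockSizeLimit=*/MaxBlockDimX);
  return Err == CUDA_SUCCESS ? BlockSize : 0;
}
```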
@@ -145,7 +141,6 @@ bool hasExceededMaxRegistersPerBlock(ur_device_handle_t Device,

 // Helper to compute kernel parameters from workload
 // dimensions.
-// @param [in] Context handler to the target Context
 // @param [in] Device handler to the target Device
 // @param [in] WorkDim workload dimension
 // @param [in] GlobalWorkOffset pointer workload global offsets
@@ -155,73 +150,56 @@ bool hasExceededMaxRegistersPerBlock(ur_device_handle_t Device,
 // @param [out] ThreadsPerBlock Number of threads per block we should run
 // @param [out] BlocksPerGrid Number of blocks per grid we should run
 ur_result_t
-setKernelParams([[maybe_unused]] const ur_context_handle_t Context,
-                const ur_device_handle_t Device, const uint32_t WorkDim,
+setKernelParams(const ur_device_handle_t Device, const uint32_t WorkDim,
                 const size_t *GlobalWorkOffset, const size_t *GlobalWorkSize,
                 const size_t *LocalWorkSize, ur_kernel_handle_t &Kernel,
                 CUfunction &CuFunc, size_t (&ThreadsPerBlock)[3],
                 size_t (&BlocksPerGrid)[3]) {
-  size_t MaxWorkGroupSize = 0u;
-  bool ProvidedLocalWorkGroupSize = LocalWorkSize != nullptr;
-
   try {
     // Set the active context here as guessLocalWorkSize needs an active context
     ScopedContext Active(Device);
-    {
-      size_t *MaxThreadsPerBlock = Kernel->MaxThreadsPerBlock;
-      size_t *ReqdThreadsPerBlock = Kernel->ReqdThreadsPerBlock;
-      MaxWorkGroupSize = Device->getMaxWorkGroupSize();
-
-      if (ProvidedLocalWorkGroupSize) {
-        auto IsValid = [&](int Dim) {
-          if (ReqdThreadsPerBlock[Dim] != 0 &&
-              LocalWorkSize[Dim] != ReqdThreadsPerBlock[Dim])
-            return UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE;
-
-          if (MaxThreadsPerBlock[Dim] != 0 &&
-              LocalWorkSize[Dim] > MaxThreadsPerBlock[Dim])
-            return UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE;
-
-          if (LocalWorkSize[Dim] > Device->getMaxWorkItemSizes(Dim))
-            return UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE;
-          // Checks that local work sizes are a divisor of the global work sizes
-          // which includes that the local work sizes are neither larger than
-          // the global work sizes and not 0.
-          if (0u == LocalWorkSize[Dim])
-            return UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE;
-          if (0u != (GlobalWorkSize[Dim] % LocalWorkSize[Dim]))
-            return UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE;
-          ThreadsPerBlock[Dim] = LocalWorkSize[Dim];
-          return UR_RESULT_SUCCESS;
-        };
-
-        size_t KernelLocalWorkGroupSize = 1;
-        for (size_t Dim = 0; Dim < WorkDim; Dim++) {
-          auto Err = IsValid(Dim);
-          if (Err != UR_RESULT_SUCCESS)
-            return Err;
-          // If no error then compute the total local work size as a product of
-          // all dims.
-          KernelLocalWorkGroupSize *= LocalWorkSize[Dim];
-        }

-        if (size_t MaxLinearThreadsPerBlock = Kernel->MaxLinearThreadsPerBlock;
-            MaxLinearThreadsPerBlock &&
-            MaxLinearThreadsPerBlock < KernelLocalWorkGroupSize) {
+    if (LocalWorkSize != nullptr) {
+      size_t KernelLocalWorkGroupSize = 1;
+      for (size_t i = 0; i < WorkDim; i++) {
+        if (Kernel->ReqdThreadsPerBlock[i] &&
+            Kernel->ReqdThreadsPerBlock[i] != LocalWorkSize[i])
           return UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE;
-        }

-        if (hasExceededMaxRegistersPerBlock(Device, Kernel,
-                                            KernelLocalWorkGroupSize)) {
-          return UR_RESULT_ERROR_OUT_OF_RESOURCES;
-        }
-      } else {
-        guessLocalWorkSize(Device, ThreadsPerBlock, GlobalWorkSize, WorkDim,
-                           Kernel);
+        if (Kernel->MaxThreadsPerBlock[i] &&
+            Kernel->MaxThreadsPerBlock[i] < LocalWorkSize[i])
+          return UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE;
+
+        if (LocalWorkSize[i] > Device->getMaxWorkItemSizes(i))
+          return UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE;
+        // Checks that local work sizes are a divisor of the global work sizes
+        // which includes that the local work sizes are neither larger than
+        // the global work sizes and not 0.
+        if (0u == LocalWorkSize[i] ||
+            0u != (GlobalWorkSize[i] % LocalWorkSize[i]))
+          return UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE;
+
+        ThreadsPerBlock[i] = LocalWorkSize[i];
+
+        // Compute the total local work size as a product of all dims.
+        KernelLocalWorkGroupSize *= LocalWorkSize[i];
       }
+
+      if (Kernel->MaxLinearThreadsPerBlock &&
+          Kernel->MaxLinearThreadsPerBlock < KernelLocalWorkGroupSize) {
+        return UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE;
+      }
+
+      if (hasExceededMaxRegistersPerBlock(Device, Kernel,
+                                          KernelLocalWorkGroupSize)) {
+        return UR_RESULT_ERROR_OUT_OF_RESOURCES;
+      }
+    } else {
+      guessLocalWorkSize(Device, ThreadsPerBlock, GlobalWorkSize, WorkDim,
+                         Kernel);
     }

-    if (MaxWorkGroupSize <
+    if (Device->getMaxWorkGroupSize() <
         ThreadsPerBlock[0] * ThreadsPerBlock[1] * ThreadsPerBlock[2]) {
       return UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE;
     }
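The refactor folds the old per-dimension `IsValid` lambda into one flat loop with early returns. A compilable sketch of that control flow, using stand-in types (`KernelLimits`, `validateLocalSize`, and the `bool` result are illustrative; the adapter returns `ur_result_t` codes instead):

```cpp
#include <cstddef>

// Illustrative stand-ins for the cached kernel properties; in the
// adapter these live on ur_kernel_handle_t_.
struct KernelLimits {
  size_t ReqdThreadsPerBlock[3];   // 0 == no reqd_work_group_size metadata
  size_t MaxThreadsPerBlock[3];    // 0 == no max_work_group_size metadata
  size_t MaxLinearThreadsPerBlock; // 0 == no linear limit
};

// Mirrors the flattened loop: one early-exit check per rule,
// accumulating the linear work-group size along the way.
bool validateLocalSize(const KernelLimits &K, const size_t *DeviceMax,
                       const size_t *GlobalSize, const size_t *LocalSize,
                       unsigned WorkDim, size_t *ThreadsPerBlock) {
  size_t Linear = 1;
  for (unsigned i = 0; i < WorkDim; i++) {
    if (K.ReqdThreadsPerBlock[i] && K.ReqdThreadsPerBlock[i] != LocalSize[i])
      return false; // must match reqd_work_group_size exactly
    if (K.MaxThreadsPerBlock[i] && K.MaxThreadsPerBlock[i] < LocalSize[i])
      return false; // exceeds the kernel's per-dimension cap
    if (LocalSize[i] > DeviceMax[i])
      return false; // exceeds the device's per-dimension cap
    if (LocalSize[i] == 0 || GlobalSize[i] % LocalSize[i] != 0)
      return false; // must be a non-zero divisor of the global size
    ThreadsPerBlock[i] = LocalSize[i];
    Linear *= LocalSize[i];
  }
  return !(K.MaxLinearThreadsPerBlock && K.MaxLinearThreadsPerBlock < Linear);
}
```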
@@ -407,10 +385,9 @@ enqueueKernelLaunch(ur_queue_handle_t hQueue, ur_kernel_handle_t hKernel,

   // This might return UR_RESULT_ERROR_ADAPTER_SPECIFIC, which cannot be handled
   // using the standard UR_CHECK_ERROR
-  if (ur_result_t Ret =
-          setKernelParams(hQueue->getContext(), hQueue->Device, workDim,
-                          pGlobalWorkOffset, pGlobalWorkSize, pLocalWorkSize,
-                          hKernel, CuFunc, ThreadsPerBlock, BlocksPerGrid);
+  if (ur_result_t Ret = setKernelParams(
+          hQueue->Device, workDim, pGlobalWorkOffset, pGlobalWorkSize,
+          pLocalWorkSize, hKernel, CuFunc, ThreadsPerBlock, BlocksPerGrid);
       Ret != UR_RESULT_SUCCESS)
     return Ret;

@@ -595,10 +572,9 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunchCustomExp(

   // This might return UR_RESULT_ERROR_ADAPTER_SPECIFIC, which cannot be handled
   // using the standard UR_CHECK_ERROR
-  if (ur_result_t Ret =
-          setKernelParams(hQueue->getContext(), hQueue->Device, workDim,
-                          pGlobalWorkOffset, pGlobalWorkSize, pLocalWorkSize,
-                          hKernel, CuFunc, ThreadsPerBlock, BlocksPerGrid);
+  if (ur_result_t Ret = setKernelParams(
+          hQueue->Device, workDim, pGlobalWorkOffset, pGlobalWorkSize,
+          pLocalWorkSize, hKernel, CuFunc, ThreadsPerBlock, BlocksPerGrid);
       Ret != UR_RESULT_SUCCESS)
     return Ret;
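Both launch paths rely on C++17's if-with-initializer so the result of `setKernelParams` is scoped to the error check and cannot leak into the rest of the function. A minimal illustration with stand-in types:

```cpp
#include <cstdio>

enum Result { Success = 0, Failure = 1 }; // stand-in for ur_result_t

Result mightFail(int x) { return x < 0 ? Failure : Success; }

// Same shape as the hunks above: declare, test, and forward the error
// in a single statement; Ret is not visible past the if.
Result caller(int x) {
  if (Result Ret = mightFail(x); Ret != Success)
    return Ret;
  std::puts("launch parameters accepted");
  return Success;
}

int main() { return caller(42); }
```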

unified-runtime/source/adapters/cuda/enqueue.hpp

Lines changed: 1 addition & 2 deletions
@@ -55,8 +55,7 @@ bool hasExceededMaxRegistersPerBlock(ur_device_handle_t Device,
                                      size_t BlockSize);

 ur_result_t
-setKernelParams(const ur_context_handle_t Context,
-                const ur_device_handle_t Device, const uint32_t WorkDim,
+setKernelParams(const ur_device_handle_t Device, const uint32_t WorkDim,
                 const size_t *GlobalWorkOffset, const size_t *GlobalWorkSize,
                 const size_t *LocalWorkSize, ur_kernel_handle_t &Kernel,
                 CUfunction &CuFunc, size_t (&ThreadsPerBlock)[3],

unified-runtime/source/adapters/cuda/kernel.cpp

Lines changed: 3 additions & 31 deletions
@@ -90,17 +90,7 @@ urKernelGetGroupInfo(ur_kernel_handle_t hKernel, ur_device_handle_t hDevice,
     return ReturnValue(size_t(MaxThreads));
   }
   case UR_KERNEL_GROUP_INFO_COMPILE_WORK_GROUP_SIZE: {
-    size_t GroupSize[3] = {0, 0, 0};
-    const auto &ReqdWGSizeMDMap =
-        hKernel->getProgram()->KernelReqdWorkGroupSizeMD;
-    const auto ReqdWGSizeMD = ReqdWGSizeMDMap.find(hKernel->getName());
-    if (ReqdWGSizeMD != ReqdWGSizeMDMap.end()) {
-      const auto ReqdWGSize = ReqdWGSizeMD->second;
-      GroupSize[0] = std::get<0>(ReqdWGSize);
-      GroupSize[1] = std::get<1>(ReqdWGSize);
-      GroupSize[2] = std::get<2>(ReqdWGSize);
-    }
-    return ReturnValue(GroupSize, 3);
+    return ReturnValue(hKernel->ReqdThreadsPerBlock, 3);
   }
   case UR_KERNEL_GROUP_INFO_LOCAL_MEM_SIZE: {
     // OpenCL LOCAL == CUDA SHARED
@@ -124,28 +114,10 @@ urKernelGetGroupInfo(ur_kernel_handle_t hKernel, ur_device_handle_t hDevice,
     return ReturnValue(uint64_t(Bytes));
   }
   case UR_KERNEL_GROUP_INFO_COMPILE_MAX_WORK_GROUP_SIZE: {
-    size_t MaxGroupSize[3] = {0, 0, 0};
-    const auto &MaxWGSizeMDMap =
-        hKernel->getProgram()->KernelMaxWorkGroupSizeMD;
-    const auto MaxWGSizeMD = MaxWGSizeMDMap.find(hKernel->getName());
-    if (MaxWGSizeMD != MaxWGSizeMDMap.end()) {
-      const auto MaxWGSize = MaxWGSizeMD->second;
-      MaxGroupSize[0] = std::get<0>(MaxWGSize);
-      MaxGroupSize[1] = std::get<1>(MaxWGSize);
-      MaxGroupSize[2] = std::get<2>(MaxWGSize);
-    }
-    return ReturnValue(MaxGroupSize, 3);
+    return ReturnValue(hKernel->MaxThreadsPerBlock, 3);
   }
   case UR_KERNEL_GROUP_INFO_COMPILE_MAX_LINEAR_WORK_GROUP_SIZE: {
-    size_t MaxLinearGroupSize = 0;
-    const auto &MaxLinearWGSizeMDMap =
-        hKernel->getProgram()->KernelMaxLinearWorkGroupSizeMD;
-    const auto MaxLinearWGSizeMD =
-        MaxLinearWGSizeMDMap.find(hKernel->getName());
-    if (MaxLinearWGSizeMD != MaxLinearWGSizeMDMap.end()) {
-      MaxLinearGroupSize = MaxLinearWGSizeMD->second;
-    }
-    return ReturnValue(MaxLinearGroupSize);
+    return ReturnValue(hKernel->MaxLinearThreadsPerBlock);
   }
   default:
     break;
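Callers observe no behavioral change from these hunks: the same `urKernelGetGroupInfo` queries now answer from the values cached on the kernel handle at construction time. A sketch of one query through the public entry point; `queryCompileWGSize` is a hypothetical wrapper, assuming the standard `ur_api.h` signature:

```cpp
#include <ur_api.h>
#include <cstddef>

// Query the compile-time (reqd_work_group_size) work-group size.
// After this commit the CUDA adapter serves it from
// hKernel->ReqdThreadsPerBlock instead of re-walking program metadata.
ur_result_t queryCompileWGSize(ur_kernel_handle_t hKernel,
                               ur_device_handle_t hDevice,
                               size_t (&WGSize)[3]) {
  return urKernelGetGroupInfo(hKernel, hDevice,
                              UR_KERNEL_GROUP_INFO_COMPILE_WORK_GROUP_SIZE,
                              sizeof(WGSize), WGSize,
                              /*pPropSizeRet=*/nullptr);
}
```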

unified-runtime/source/adapters/cuda/kernel.hpp

Lines changed: 37 additions & 19 deletions
@@ -258,25 +258,43 @@ struct ur_kernel_handle_t_ : ur::cuda::handle_base {
         Context{Context}, Program{Program}, RefCount{1} {
     urProgramRetain(Program);
     urContextRetain(Context);
-    /// Note: this code assumes that there is only one device per context
-    ur_result_t RetError = urKernelGetGroupInfo(
-        this, Program->getDevice(),
-        UR_KERNEL_GROUP_INFO_COMPILE_WORK_GROUP_SIZE,
-        sizeof(ReqdThreadsPerBlock), ReqdThreadsPerBlock, nullptr);
-    (void)RetError;
-    assert(RetError == UR_RESULT_SUCCESS);
-    /// Note: this code assumes that there is only one device per context
-    RetError = urKernelGetGroupInfo(
-        this, Program->getDevice(),
-        UR_KERNEL_GROUP_INFO_COMPILE_MAX_WORK_GROUP_SIZE,
-        sizeof(MaxThreadsPerBlock), MaxThreadsPerBlock, nullptr);
-    assert(RetError == UR_RESULT_SUCCESS);
-    /// Note: this code assumes that there is only one device per context
-    RetError = urKernelGetGroupInfo(
-        this, Program->getDevice(),
-        UR_KERNEL_GROUP_INFO_COMPILE_MAX_LINEAR_WORK_GROUP_SIZE,
-        sizeof(MaxLinearThreadsPerBlock), &MaxLinearThreadsPerBlock, nullptr);
-    assert(RetError == UR_RESULT_SUCCESS);
+
+    // Get reqd work group size
+    const auto &ReqdWGSizeMDMap = Program->KernelReqdWorkGroupSizeMD;
+    const auto ReqdWGSizeMD = ReqdWGSizeMDMap.find(Name);
+    if (ReqdWGSizeMD != ReqdWGSizeMDMap.end()) {
+      const auto ReqdWGSize = ReqdWGSizeMD->second;
+      ReqdThreadsPerBlock[0] = std::get<0>(ReqdWGSize);
+      ReqdThreadsPerBlock[1] = std::get<1>(ReqdWGSize);
+      ReqdThreadsPerBlock[2] = std::get<2>(ReqdWGSize);
+    } else {
+      ReqdThreadsPerBlock[0] = 0;
+      ReqdThreadsPerBlock[1] = 0;
+      ReqdThreadsPerBlock[2] = 0;
+    }
+
+    // Get max work group size
+    const auto &MaxWGSizeMDMap = Program->KernelMaxWorkGroupSizeMD;
+    const auto MaxWGSizeMD = MaxWGSizeMDMap.find(Name);
+    if (MaxWGSizeMD != MaxWGSizeMDMap.end()) {
+      const auto MaxWGSize = MaxWGSizeMD->second;
+      MaxThreadsPerBlock[0] = std::get<0>(MaxWGSize);
+      MaxThreadsPerBlock[1] = std::get<1>(MaxWGSize);
+      MaxThreadsPerBlock[2] = std::get<2>(MaxWGSize);
+    } else {
+      MaxThreadsPerBlock[0] = 0;
+      MaxThreadsPerBlock[1] = 0;
+      MaxThreadsPerBlock[2] = 0;
+    }
+
+    // Get max linear work group size
+    MaxLinearThreadsPerBlock = 0;
+    const auto MaxLinearWGSizeMD =
+        Program->KernelMaxLinearWorkGroupSizeMD.find(Name);
+    if (MaxLinearWGSizeMD != Program->KernelMaxLinearWorkGroupSizeMD.end()) {
+      MaxLinearThreadsPerBlock = MaxLinearWGSizeMD->second;
+    }
+
     UR_CHECK_ERROR(
         cuFuncGetAttribute(&RegsPerThread, CU_FUNC_ATTRIBUTE_NUM_REGS, Func));
   }
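All three blocks in the constructor follow the same find-or-zero pattern against one of the program's metadata maps. A reduced sketch of that pattern; `WGSizeMap` and `cacheWGSize` are illustrative, and the `uint32_t` tuple element type is an assumption, not taken from this diff:

```cpp
#include <array>
#include <cstddef>
#include <cstdint>
#include <string>
#include <tuple>
#include <unordered_map>

// Assumed shape of a name -> (x, y, z) metadata map.
using WGSizeMap =
    std::unordered_map<std::string, std::tuple<uint32_t, uint32_t, uint32_t>>;

// Look the kernel name up once at construction time and keep the result
// (or zeros, meaning "no metadata") for cheap later reads.
std::array<size_t, 3> cacheWGSize(const WGSizeMap &MD,
                                  const std::string &Name) {
  if (auto It = MD.find(Name); It != MD.end()) {
    const auto &[X, Y, Z] = It->second;
    return {X, Y, Z};
  }
  return {0, 0, 0};
}
```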

unified-runtime/source/adapters/hip/device.hpp

Lines changed: 12 additions & 10 deletions
@@ -28,9 +28,7 @@ struct ur_device_handle_t_ : ur::hip::handle_base {
   uint32_t DeviceIndex;

   int MaxWorkGroupSize{0};
-  int MaxBlockDimX{0};
-  int MaxBlockDimY{0};
-  int MaxBlockDimZ{0};
+  size_t MaxBlockDim[3];
   int MaxCapacityLocalMem{0};
   int MaxChosenLocalMem{0};
   int ManagedMemSupport{0};
@@ -45,12 +43,18 @@ struct ur_device_handle_t_ : ur::hip::handle_base {

     UR_CHECK_ERROR(hipDeviceGetAttribute(
         &MaxWorkGroupSize, hipDeviceAttributeMaxThreadsPerBlock, HIPDevice));
+
+    int MaxDim;
     UR_CHECK_ERROR(hipDeviceGetAttribute(
-        &MaxBlockDimX, hipDeviceAttributeMaxBlockDimX, HIPDevice));
+        &MaxDim, hipDeviceAttributeMaxBlockDimX, HIPDevice));
+    MaxBlockDim[0] = size_t(MaxDim);
     UR_CHECK_ERROR(hipDeviceGetAttribute(
-        &MaxBlockDimY, hipDeviceAttributeMaxBlockDimY, HIPDevice));
+        &MaxDim, hipDeviceAttributeMaxBlockDimY, HIPDevice));
+    MaxBlockDim[1] = size_t(MaxDim);
     UR_CHECK_ERROR(hipDeviceGetAttribute(
-        &MaxBlockDimZ, hipDeviceAttributeMaxBlockDimZ, HIPDevice));
+        &MaxDim, hipDeviceAttributeMaxBlockDimZ, HIPDevice));
+    MaxBlockDim[2] = size_t(MaxDim);
+
     UR_CHECK_ERROR(hipDeviceGetAttribute(
         &MaxCapacityLocalMem, hipDeviceAttributeMaxSharedMemoryPerBlock,
         HIPDevice));
@@ -107,11 +111,9 @@ struct ur_device_handle_t_ : ur::hip::handle_base {

   int getMaxWorkGroupSize() const noexcept { return MaxWorkGroupSize; };

-  int getMaxBlockDimX() const noexcept { return MaxBlockDimX; };
-
-  int getMaxBlockDimY() const noexcept { return MaxBlockDimY; };
+  size_t getMaxBlockDim(int dim) const noexcept { return MaxBlockDim[dim]; };

-  int getMaxBlockDimZ() const noexcept { return MaxBlockDimZ; };
+  const size_t *getMaxBlockDim() const noexcept { return MaxBlockDim; };

   int getMaxCapacityLocalMem() const noexcept { return MaxCapacityLocalMem; };