@@ -119,18 +119,14 @@ void guessLocalWorkSize(ur_device_handle_t Device, size_t *ThreadsPerBlock,
     GlobalSizeNormalized[i] = GlobalWorkSize[i];
   }

-  size_t MaxBlockDim[3];
-  MaxBlockDim[0] = Device->getMaxWorkItemSizes(0);
-  MaxBlockDim[1] = Device->getMaxWorkItemSizes(1);
-  MaxBlockDim[2] = Device->getMaxWorkItemSizes(2);
-
   int MinGrid, MaxBlockSize;
   UR_CHECK_ERROR(cuOccupancyMaxPotentialBlockSize(
       &MinGrid, &MaxBlockSize, Kernel->get(), NULL, Kernel->getLocalSize(),
-      MaxBlockDim[0]));
+      Device->getMaxWorkItemSizes(0)));

   roundToHighestFactorOfGlobalSizeIn3d(ThreadsPerBlock, GlobalSizeNormalized,
-                                       MaxBlockDim, MaxBlockSize);
+                                       Device->getMaxWorkItemSizes(),
+                                       MaxBlockSize);
 }

 // Helper to verify out-of-registers case (exceeded block max registers).
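For context, `cuOccupancyMaxPotentialBlockSize` is the CUDA driver API used above: it suggests a 1-D block size that maximizes occupancy, capped by its final `blockSizeLimit` argument, here the device's maximum work-item size in dimension 0; `roundToHighestFactorOfGlobalSizeIn3d` then shrinks that suggestion to a divisor of the global size. A minimal sketch of the occupancy query, where the wrapper name and parameters are hypothetical:

```cpp
#include <cuda.h>

// Hypothetical helper: ask the driver for an occupancy-maximizing 1-D block
// size for Func, capped at MaxDim0 threads (the dimension-0 device limit).
int suggestBlockSize(CUfunction Func, size_t DynamicSMemBytes, int MaxDim0) {
  int MinGridSize = 0, BlockSize = 0;
  // NULL: dynamic shared memory usage does not vary with the block size.
  cuOccupancyMaxPotentialBlockSize(&MinGridSize, &BlockSize, Func,
                                   /*blockSizeToDynamicSMemSize=*/NULL,
                                   DynamicSMemBytes,
                                   /*blockSizeLimit=*/MaxDim0);
  return BlockSize;
}
```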
@@ -145,7 +141,6 @@ bool hasExceededMaxRegistersPerBlock(ur_device_handle_t Device,

 // Helper to compute kernel parameters from workload
 // dimensions.
-// @param [in] Context handler to the target Context
 // @param [in] Device handler to the target Device
 // @param [in] WorkDim workload dimension
 // @param [in] GlobalWorkOffset pointer workload global offsets
@@ -155,73 +150,56 @@ bool hasExceededMaxRegistersPerBlock(ur_device_handle_t Device,
 // @param [out] ThreadsPerBlock Number of threads per block we should run
 // @param [out] BlocksPerGrid Number of blocks per grid we should run
 ur_result_t
-setKernelParams([[maybe_unused]] const ur_context_handle_t Context,
-                const ur_device_handle_t Device, const uint32_t WorkDim,
+setKernelParams(const ur_device_handle_t Device, const uint32_t WorkDim,
                 const size_t *GlobalWorkOffset, const size_t *GlobalWorkSize,
                 const size_t *LocalWorkSize, ur_kernel_handle_t &Kernel,
                 CUfunction &CuFunc, size_t (&ThreadsPerBlock)[3],
                 size_t (&BlocksPerGrid)[3]) {
-  size_t MaxWorkGroupSize = 0u;
-  bool ProvidedLocalWorkGroupSize = LocalWorkSize != nullptr;
-
   try {
     // Set the active context here as guessLocalWorkSize needs an active context
     ScopedContext Active(Device);
-    {
-      size_t *MaxThreadsPerBlock = Kernel->MaxThreadsPerBlock;
-      size_t *ReqdThreadsPerBlock = Kernel->ReqdThreadsPerBlock;
-      MaxWorkGroupSize = Device->getMaxWorkGroupSize();
-
-      if (ProvidedLocalWorkGroupSize) {
-        auto IsValid = [&](int Dim) {
-          if (ReqdThreadsPerBlock[Dim] != 0 &&
-              LocalWorkSize[Dim] != ReqdThreadsPerBlock[Dim])
-            return UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE;
-
-          if (MaxThreadsPerBlock[Dim] != 0 &&
-              LocalWorkSize[Dim] > MaxThreadsPerBlock[Dim])
-            return UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE;
-
-          if (LocalWorkSize[Dim] > Device->getMaxWorkItemSizes(Dim))
-            return UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE;
-          // Checks that local work sizes are a divisor of the global work sizes
-          // which includes that the local work sizes are neither larger than
-          // the global work sizes and not 0.
-          if (0u == LocalWorkSize[Dim])
-            return UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE;
-          if (0u != (GlobalWorkSize[Dim] % LocalWorkSize[Dim]))
-            return UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE;
-          ThreadsPerBlock[Dim] = LocalWorkSize[Dim];
-          return UR_RESULT_SUCCESS;
-        };
-
-        size_t KernelLocalWorkGroupSize = 1;
-        for (size_t Dim = 0; Dim < WorkDim; Dim++) {
-          auto Err = IsValid(Dim);
-          if (Err != UR_RESULT_SUCCESS)
-            return Err;
-          // If no error then compute the total local work size as a product of
-          // all dims.
-          KernelLocalWorkGroupSize *= LocalWorkSize[Dim];
-        }

-        if (size_t MaxLinearThreadsPerBlock = Kernel->MaxLinearThreadsPerBlock;
-            MaxLinearThreadsPerBlock &&
-            MaxLinearThreadsPerBlock < KernelLocalWorkGroupSize) {
+    if (LocalWorkSize != nullptr) {
+      size_t KernelLocalWorkGroupSize = 1;
+      for (size_t i = 0; i < WorkDim; i++) {
+        if (Kernel->ReqdThreadsPerBlock[i] &&
+            Kernel->ReqdThreadsPerBlock[i] != LocalWorkSize[i])
           return UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE;
-        }

-        if (hasExceededMaxRegistersPerBlock(Device, Kernel,
-                                            KernelLocalWorkGroupSize)) {
-          return UR_RESULT_ERROR_OUT_OF_RESOURCES;
-        }
-      } else {
-        guessLocalWorkSize(Device, ThreadsPerBlock, GlobalWorkSize, WorkDim,
-                           Kernel);
+        if (Kernel->MaxThreadsPerBlock[i] &&
+            Kernel->MaxThreadsPerBlock[i] < LocalWorkSize[i])
+          return UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE;
+
+        if (LocalWorkSize[i] > Device->getMaxWorkItemSizes(i))
+          return UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE;
+        // Checks that the local work sizes are a divisor of the global work
+        // sizes, which implies that they are neither larger than the global
+        // work sizes nor 0.
+        if (0u == LocalWorkSize[i] ||
+            0u != (GlobalWorkSize[i] % LocalWorkSize[i]))
+          return UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE;
+
+        ThreadsPerBlock[i] = LocalWorkSize[i];
+
+        // Compute the total local work size as a product of all dims.
+        KernelLocalWorkGroupSize *= LocalWorkSize[i];
       }
+
+      if (Kernel->MaxLinearThreadsPerBlock &&
+          Kernel->MaxLinearThreadsPerBlock < KernelLocalWorkGroupSize) {
+        return UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE;
+      }
+
+      if (hasExceededMaxRegistersPerBlock(Device, Kernel,
+                                          KernelLocalWorkGroupSize)) {
+        return UR_RESULT_ERROR_OUT_OF_RESOURCES;
+      }
+    } else {
+      guessLocalWorkSize(Device, ThreadsPerBlock, GlobalWorkSize, WorkDim,
+                         Kernel);
     }

-    if (MaxWorkGroupSize <
+    if (Device->getMaxWorkGroupSize() <
         ThreadsPerBlock[0] * ThreadsPerBlock[1] * ThreadsPerBlock[2]) {
       return UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE;
     }
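The rewritten loop above replaces the old `IsValid` lambda with inline checks, but the rules are unchanged: a user-supplied local size must match any required work-group size recorded on the kernel, respect the kernel's and device's per-dimension caps, and be a non-zero divisor of the global size. A condensed sketch of the same per-dimension predicate, using hypothetical standalone parameters in place of the UR handles:

```cpp
#include <cstddef>

// Hypothetical standalone version of the per-dimension checks above:
// returns true when Local is a usable work-group size for one dimension.
bool isValidLocalDim(size_t Local, size_t Global, size_t ReqdSize,
                     size_t KernelMax, size_t DeviceMax) {
  if (ReqdSize && Local != ReqdSize)  // required work-group size must match
    return false;
  if (KernelMax && Local > KernelMax) // kernel's per-dimension cap
    return false;
  if (Local > DeviceMax)              // device's per-dimension cap
    return false;
  // Must be a non-zero divisor of the global size, which also implies
  // Local <= Global.
  return Local != 0 && Global % Local == 0;
}
```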
@@ -407,10 +385,9 @@ enqueueKernelLaunch(ur_queue_handle_t hQueue, ur_kernel_handle_t hKernel,

   // This might return UR_RESULT_ERROR_ADAPTER_SPECIFIC, which cannot be handled
   // using the standard UR_CHECK_ERROR
-  if (ur_result_t Ret =
-          setKernelParams(hQueue->getContext(), hQueue->Device, workDim,
-                          pGlobalWorkOffset, pGlobalWorkSize, pLocalWorkSize,
-                          hKernel, CuFunc, ThreadsPerBlock, BlocksPerGrid);
+  if (ur_result_t Ret = setKernelParams(
+          hQueue->Device, workDim, pGlobalWorkOffset, pGlobalWorkSize,
+          pLocalWorkSize, hKernel, CuFunc, ThreadsPerBlock, BlocksPerGrid);
       Ret != UR_RESULT_SUCCESS)
     return Ret;

@@ -595,10 +572,9 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunchCustomExp(

   // This might return UR_RESULT_ERROR_ADAPTER_SPECIFIC, which cannot be handled
   // using the standard UR_CHECK_ERROR
-  if (ur_result_t Ret =
-          setKernelParams(hQueue->getContext(), hQueue->Device, workDim,
-                          pGlobalWorkOffset, pGlobalWorkSize, pLocalWorkSize,
-                          hKernel, CuFunc, ThreadsPerBlock, BlocksPerGrid);
+  if (ur_result_t Ret = setKernelParams(
+          hQueue->Device, workDim, pGlobalWorkOffset, pGlobalWorkSize,
+          pLocalWorkSize, hKernel, CuFunc, ThreadsPerBlock, BlocksPerGrid);
       Ret != UR_RESULT_SUCCESS)
     return Ret;

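Both call sites keep the C++17 if-with-initializer form, which scopes the result variable to the error check itself. A minimal sketch of the pattern, with a hypothetical fallible call standing in for `setKernelParams`:

```cpp
#include <cstdint>

using result_t = std::int32_t;
constexpr result_t SUCCESS = 0;

result_t mayFail() { return SUCCESS; } // hypothetical fallible call

result_t caller() {
  // Ret exists only within this if statement, so it cannot leak or go stale.
  if (result_t Ret = mayFail(); Ret != SUCCESS)
    return Ret;
  return SUCCESS;
}
```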