@@ -5323,9 +5323,41 @@ piEnqueueKernelLaunch(pi_queue Queue, pi_kernel Kernel, pi_uint32 WorkDim,
5323
5323
WG[1 ] = pi_cast<uint32_t >(LocalWorkSize[1 ]);
5324
5324
WG[2 ] = pi_cast<uint32_t >(LocalWorkSize[2 ]);
5325
5325
} else {
5326
- ZE_CALL (zeKernelSuggestGroupSize,
5327
- (Kernel->ZeKernel , GlobalWorkSize[0 ], GlobalWorkSize[1 ],
5328
- GlobalWorkSize[2 ], &WG[0 ], &WG[1 ], &WG[2 ]));
5326
+ // We can't call to zeKernelSuggestGroupSize if 64-bit GlobalWorkSize
5327
+ // values do not fit to 32-bit that the API only supports currently.
5328
+ bool SuggestGroupSize = true ;
5329
+ for (int I : {0 , 1 , 2 }) {
5330
+ if (GlobalWorkSize[I] > UINT32_MAX) {
5331
+ SuggestGroupSize = false ;
5332
+ }
5333
+ }
5334
+ if (SuggestGroupSize) {
5335
+ ZE_CALL (zeKernelSuggestGroupSize,
5336
+ (Kernel->ZeKernel , GlobalWorkSize[0 ], GlobalWorkSize[1 ],
5337
+ GlobalWorkSize[2 ], &WG[0 ], &WG[1 ], &WG[2 ]));
5338
+ } else {
5339
+ for (int I : {0 , 1 , 2 }) {
5340
+ // Try to find a I-dimension WG size that the GlobalWorkSize[I] is
5341
+ // fully divisable with. Start with the max possible size in
5342
+ // each dimension.
5343
+ uint32_t GroupSize[] = {
5344
+ Queue->Device ->ZeDeviceComputeProperties ->maxGroupSizeX ,
5345
+ Queue->Device ->ZeDeviceComputeProperties ->maxGroupSizeY ,
5346
+ Queue->Device ->ZeDeviceComputeProperties ->maxGroupSizeZ };
5347
+ GroupSize[I] = std::min (size_t (GroupSize[I]), GlobalWorkSize[I]);
5348
+ while (GlobalWorkSize[I] % GroupSize[I]) {
5349
+ --GroupSize[I];
5350
+ }
5351
+ if (GlobalWorkSize[I] / GroupSize[I] > UINT32_MAX) {
5352
+ zePrint (" piEnqueueKernelLaunch: can't find a WG size "
5353
+ " suitable for global work size > UINT32_MAX\n " );
5354
+ return PI_ERROR_INVALID_WORK_GROUP_SIZE;
5355
+ }
5356
+ WG[I] = GroupSize[I];
5357
+ }
5358
+ zePrint (" piEnqueueKernelLaunch: using computed WG size = {%d, %d, %d}\n " ,
5359
+ WG[0 ], WG[1 ], WG[2 ]);
5360
+ }
5329
5361
}
5330
5362
5331
5363
// TODO: assert if sizes do not fit into 32-bit?
@@ -5357,17 +5389,20 @@ piEnqueueKernelLaunch(pi_queue Queue, pi_kernel Kernel, pi_uint32 WorkDim,
5357
5389
}
5358
5390
5359
5391
// Error handling for non-uniform group size case
5360
- if (GlobalWorkSize[0 ] != (ZeThreadGroupDimensions.groupCountX * WG[0 ])) {
5392
+ if (GlobalWorkSize[0 ] !=
5393
+ size_t (ZeThreadGroupDimensions.groupCountX ) * WG[0 ]) {
5361
5394
zePrint (" piEnqueueKernelLaunch: invalid work_dim. The range is not a "
5362
5395
" multiple of the group size in the 1st dimension\n " );
5363
5396
return PI_ERROR_INVALID_WORK_GROUP_SIZE;
5364
5397
}
5365
- if (GlobalWorkSize[1 ] != (ZeThreadGroupDimensions.groupCountY * WG[1 ])) {
5398
+ if (GlobalWorkSize[1 ] !=
5399
+ size_t (ZeThreadGroupDimensions.groupCountY ) * WG[1 ]) {
5366
5400
zePrint (" piEnqueueKernelLaunch: invalid work_dim. The range is not a "
5367
5401
" multiple of the group size in the 2nd dimension\n " );
5368
5402
return PI_ERROR_INVALID_WORK_GROUP_SIZE;
5369
5403
}
5370
- if (GlobalWorkSize[2 ] != (ZeThreadGroupDimensions.groupCountZ * WG[2 ])) {
5404
+ if (GlobalWorkSize[2 ] !=
5405
+ size_t (ZeThreadGroupDimensions.groupCountZ ) * WG[2 ]) {
5371
5406
zePrint (" piEnqueueKernelLaunch: invalid work_dim. The range is not a "
5372
5407
" multiple of the group size in the 3rd dimension\n " );
5373
5408
return PI_ERROR_INVALID_WORK_GROUP_SIZE;
0 commit comments