Skip to content

Commit 9a3240e

Browse files
Get grid/block sizes from gpu.launch_func
1 parent 9ac4a93 commit 9a3240e

File tree

1 file changed

+21
-22
lines changed

1 file changed

+21
-22
lines changed

lib/gc/Transforms/GPU/GpuToGpuOcl.cpp

Lines changed: 21 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -398,38 +398,37 @@ struct ConvertLaunch final : ConvertOpPattern<gpu::LaunchFuncOp> {
398398
IntegerAttr::get(helper.idxType,
399399
static_cast<int64_t>(binaryAttr.size())));
400400

401-
SmallVector<int32_t> globalSize;
402-
SmallVector<int32_t> localSize;
403-
SmallVector<int32_t> argSize;
404-
kernelMod->walk([&](gpu::GPUFuncOp func) {
405-
if (func.getName() == gpuLaunch.getKernelName()) {
406-
for (auto s : func.getKnownGridSize().value()) {
407-
globalSize.emplace_back(s);
408-
}
409-
for (auto s : func.getKnownBlockSize().value()) {
410-
localSize.emplace_back(s);
411-
}
412-
}
413-
});
414-
assert(globalSize.size() == 3 && localSize.size() == 3);
415-
globalSize = {globalSize[0] * localSize[0], globalSize[1] * localSize[1],
416-
globalSize[2] * localSize[2]};
401+
SmallVector<Value> gridSize;
402+
SmallVector<Value> blockSize;
403+
SmallVector<Value> argSize;
404+
gridSize.emplace_back(gpuLaunch.getGridSizeX());
405+
gridSize.emplace_back(gpuLaunch.getGridSizeY());
406+
gridSize.emplace_back(gpuLaunch.getGridSizeZ());
407+
blockSize.emplace_back(gpuLaunch.getBlockSizeX());
408+
blockSize.emplace_back(gpuLaunch.getBlockSizeY());
409+
blockSize.emplace_back(gpuLaunch.getBlockSizeZ());
410+
417411
for (auto arg : adaptor.getKernelOperands()) {
418412
auto type = arg.getType();
419413
auto size = type.isIntOrFloat() ? type.getIntOrFloatBitWidth() / 8 : 0;
420-
argSize.emplace_back(size);
414+
argSize.emplace_back(helper.idxConstant(rewriter, loc, size));
421415
}
422416

423-
auto array = [&](SmallVector<int32_t> &values) {
417+
auto array = [&](SmallVector<Value> &values) {
424418
auto size = helper.idxConstant(rewriter, loc, values.size());
425419
auto arrayPtr = rewriter.create<LLVM::AllocaOp>(loc, helper.ptrType,
426420
helper.idxType, size);
427421
for (size_t i = 0, n = values.size(); i < n; i++) {
428422
auto elementPtr = rewriter.create<LLVM::GEPOp>(
429423
loc, helper.ptrType, helper.idxType, arrayPtr,
430424
helper.idxConstant(rewriter, loc, i));
431-
rewriter.create<LLVM::StoreOp>(
432-
loc, helper.idxConstant(rewriter, loc, values[i]), elementPtr);
425+
auto value = values[i];
426+
if (auto cast = value.getDefiningOp<UnrealizedConversionCastOp>()) {
427+
assert(getConstantIntValue(cast.getOperand(0)));
428+
value = helper.idxConstant(
429+
rewriter, loc, getConstantIntValue(cast.getOperand(0)).value());
430+
}
431+
rewriter.create<LLVM::StoreOp>(loc, value, elementPtr);
433432
}
434433
return arrayPtr.getResult();
435434
};
@@ -442,8 +441,8 @@ struct ConvertLaunch final : ConvertOpPattern<gpu::LaunchFuncOp> {
442441
{helper.ptrType, helper.idxType, helper.ptrType, helper.ptrType,
443442
helper.ptrType, helper.ptrType, helper.idxType, helper.ptrType},
444443
loc,
445-
{ctx, spirvSize, spirv, name, array(globalSize), array(localSize),
446-
argNum, array(argSize)});
444+
{ctx, spirvSize, spirv, name, array(gridSize), array(blockSize), argNum,
445+
array(argSize)});
447446
auto result = createKernelCall.getResult();
448447

449448
// Save the kernel pointer to the global var using CAS

0 commit comments

Comments
 (0)