@@ -398,38 +398,37 @@ struct ConvertLaunch final : ConvertOpPattern<gpu::LaunchFuncOp> {
398
398
IntegerAttr::get (helper.idxType ,
399
399
static_cast <int64_t >(binaryAttr.size ())));
400
400
401
- SmallVector<int32_t > globalSize;
402
- SmallVector<int32_t > localSize;
403
- SmallVector<int32_t > argSize;
404
- kernelMod->walk ([&](gpu::GPUFuncOp func) {
405
- if (func.getName () == gpuLaunch.getKernelName ()) {
406
- for (auto s : func.getKnownGridSize ().value ()) {
407
- globalSize.emplace_back (s);
408
- }
409
- for (auto s : func.getKnownBlockSize ().value ()) {
410
- localSize.emplace_back (s);
411
- }
412
- }
413
- });
414
- assert (globalSize.size () == 3 && localSize.size () == 3 );
415
- globalSize = {globalSize[0 ] * localSize[0 ], globalSize[1 ] * localSize[1 ],
416
- globalSize[2 ] * localSize[2 ]};
401
+ SmallVector<Value> gridSize;
402
+ SmallVector<Value> blockSize;
403
+ SmallVector<Value> argSize;
404
+ gridSize.emplace_back (gpuLaunch.getGridSizeX ());
405
+ gridSize.emplace_back (gpuLaunch.getGridSizeY ());
406
+ gridSize.emplace_back (gpuLaunch.getGridSizeZ ());
407
+ blockSize.emplace_back (gpuLaunch.getBlockSizeX ());
408
+ blockSize.emplace_back (gpuLaunch.getBlockSizeY ());
409
+ blockSize.emplace_back (gpuLaunch.getBlockSizeZ ());
410
+
417
411
for (auto arg : adaptor.getKernelOperands ()) {
418
412
auto type = arg.getType ();
419
413
auto size = type.isIntOrFloat () ? type.getIntOrFloatBitWidth () / 8 : 0 ;
420
- argSize.emplace_back (size);
414
+ argSize.emplace_back (helper. idxConstant (rewriter, loc, size) );
421
415
}
422
416
423
- auto array = [&](SmallVector<int32_t > &values) {
417
+ auto array = [&](SmallVector<Value > &values) {
424
418
auto size = helper.idxConstant (rewriter, loc, values.size ());
425
419
auto arrayPtr = rewriter.create <LLVM::AllocaOp>(loc, helper.ptrType ,
426
420
helper.idxType , size);
427
421
for (size_t i = 0 , n = values.size (); i < n; i++) {
428
422
auto elementPtr = rewriter.create <LLVM::GEPOp>(
429
423
loc, helper.ptrType , helper.idxType , arrayPtr,
430
424
helper.idxConstant (rewriter, loc, i));
431
- rewriter.create <LLVM::StoreOp>(
432
- loc, helper.idxConstant (rewriter, loc, values[i]), elementPtr);
425
+ auto value = values[i];
426
+ if (auto cast = value.getDefiningOp <UnrealizedConversionCastOp>()) {
427
+ assert (getConstantIntValue (cast.getOperand (0 )));
428
+ value = helper.idxConstant (
429
+ rewriter, loc, getConstantIntValue (cast.getOperand (0 )).value ());
430
+ }
431
+ rewriter.create <LLVM::StoreOp>(loc, value, elementPtr);
433
432
}
434
433
return arrayPtr.getResult ();
435
434
};
@@ -442,8 +441,8 @@ struct ConvertLaunch final : ConvertOpPattern<gpu::LaunchFuncOp> {
442
441
{helper.ptrType , helper.idxType , helper.ptrType , helper.ptrType ,
443
442
helper.ptrType , helper.ptrType , helper.idxType , helper.ptrType },
444
443
loc,
445
- {ctx, spirvSize, spirv, name, array (globalSize ), array (localSize) ,
446
- argNum, array (argSize)});
444
+ {ctx, spirvSize, spirv, name, array (gridSize ), array (blockSize), argNum ,
445
+ array (argSize)});
447
446
auto result = createKernelCall.getResult ();
448
447
449
448
// Save the kernel pointer to the global var using CAS
0 commit comments