@@ -415,11 +415,30 @@ struct ConvertLaunch final : ConvertOpPattern<gpu::LaunchFuncOp> {
415
415
blockSize.emplace_back (gpuLaunch.getBlockSizeY ());
416
416
blockSize.emplace_back (gpuLaunch.getBlockSizeZ ());
417
417
418
- for (auto arg : adaptor .getKernelOperands ()) {
418
+ for (auto arg : gpuLaunch .getKernelOperands ()) {
419
419
auto type = arg.getType ();
420
- // Assuming, that the value is either an integer or a float or a pointer.
421
- // In the latter case, the size is 0 bytes.
422
- auto size = type.isIntOrFloat () ? type.getIntOrFloatBitWidth () / 8 : 0 ;
420
+ size_t size;
421
+ if (isa<MemRefType>(type)) {
422
+ size = 0 ; // A special case for pointers
423
+ } else if (type.isIndex ()) {
424
+ size = helper.idxType .getIntOrFloatBitWidth () / 8 ;
425
+ } else if (type.isIntOrFloat ()) {
426
+ size = type.getIntOrFloatBitWidth () / 8 ;
427
+ } else if (auto vectorType = dyn_cast<VectorType>(type)) {
428
+ type = vectorType.getElementType ();
429
+ if (type.isIntOrFloat ()) {
430
+ size = type.getIntOrFloatBitWidth ();
431
+ } else if (type.isIndex ()) {
432
+ size = helper.idxType .getIntOrFloatBitWidth ();
433
+ } else {
434
+ llvm::errs () << " Unsupported vector element type: " << type << " \n " ;
435
+ return false ;
436
+ }
437
+ size *= vectorType.getNumElements () / 8 ;
438
+ } else {
439
+ llvm::errs () << " Unsupported type: " << type << " \n " ;
440
+ return false ;
441
+ }
423
442
argSize.emplace_back (helper.idxConstant (rewriter, loc, size));
424
443
}
425
444
0 commit comments