@@ -134,10 +134,18 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
134
134
mod.walk ([&](mlir::Operation *op) {
135
135
if (auto call = mlir::dyn_cast<fir::CallOp>(op)) {
136
136
if (!hasPortableSignature (call.getFunctionType (), op))
137
- convertCallOp (call);
137
+ convertCallOp (call, call. getFunctionType () );
138
138
} else if (auto dispatch = mlir::dyn_cast<fir::DispatchOp>(op)) {
139
139
if (!hasPortableSignature (dispatch.getFunctionType (), op))
140
- convertCallOp (dispatch);
140
+ convertCallOp (dispatch, dispatch.getFunctionType ());
141
+ } else if (auto gpuLaunchFunc =
142
+ mlir::dyn_cast<mlir::gpu::LaunchFuncOp>(op)) {
143
+ llvm::SmallVector<mlir::Type> operandsTypes;
144
+ for (auto arg : gpuLaunchFunc.getKernelOperands ())
145
+ operandsTypes.push_back (arg.getType ());
146
+ auto fctTy = mlir::FunctionType::get (&context, operandsTypes, {});
147
+ if (!hasPortableSignature (fctTy, op))
148
+ convertCallOp (gpuLaunchFunc, fctTy);
141
149
} else if (auto addr = mlir::dyn_cast<fir::AddrOfOp>(op)) {
142
150
if (mlir::isa<mlir::FunctionType>(addr.getType ()) &&
143
151
!hasPortableSignature (addr.getType (), op))
@@ -357,8 +365,7 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
357
365
358
366
// Convert fir.call and fir.dispatch Ops.
359
367
template <typename A>
360
- void convertCallOp (A callOp) {
361
- auto fnTy = callOp.getFunctionType ();
368
+ void convertCallOp (A callOp, mlir::FunctionType fnTy) {
362
369
auto loc = callOp.getLoc ();
363
370
rewriter->setInsertionPoint (callOp);
364
371
llvm::SmallVector<mlir::Type> newResTys;
@@ -376,7 +383,7 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
376
383
newOpers.push_back (callOp.getOperand (0 ));
377
384
dropFront = 1 ;
378
385
}
379
- } else {
386
+ } else if constexpr (std::is_same_v<std:: decay_t <A>, fir::DispatchOp>) {
380
387
dropFront = 1 ; // First operand is the polymorphic object.
381
388
}
382
389
@@ -402,10 +409,14 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
402
409
403
410
llvm::SmallVector<mlir::Type> trailingInTys;
404
411
llvm::SmallVector<mlir::Value> trailingOpers;
412
+ llvm::SmallVector<mlir::Value> operands;
405
413
unsigned passArgShift = 0 ;
414
+ if constexpr (std::is_same_v<std::decay_t <A>, mlir::gpu::LaunchFuncOp>)
415
+ operands = callOp.getKernelOperands ();
416
+ else
417
+ operands = callOp.getOperands ().drop_front (dropFront);
406
418
for (auto e : llvm::enumerate (
407
- llvm::zip (fnTy.getInputs ().drop_front (dropFront),
408
- callOp.getOperands ().drop_front (dropFront)))) {
419
+ llvm::zip (fnTy.getInputs ().drop_front (dropFront), operands))) {
409
420
mlir::Type ty = std::get<0 >(e.value ());
410
421
mlir::Value oper = std::get<1 >(e.value ());
411
422
unsigned index = e.index ();
@@ -507,7 +518,19 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
507
518
newOpers.insert (newOpers.end (), trailingOpers.begin (), trailingOpers.end ());
508
519
509
520
llvm::SmallVector<mlir::Value, 1 > newCallResults;
510
- if constexpr (std::is_same_v<std::decay_t <A>, fir::CallOp>) {
521
+ if constexpr (std::is_same_v<std::decay_t <A>, mlir::gpu::LaunchFuncOp>) {
522
+ auto newCall = rewriter->create <A>(
523
+ loc, callOp.getKernel (), callOp.getGridSizeOperandValues (),
524
+ callOp.getBlockSizeOperandValues (),
525
+ callOp.getDynamicSharedMemorySize (), newOpers);
526
+ if (callOp.getClusterSizeX ())
527
+ newCall.getClusterSizeXMutable ().assign (callOp.getClusterSizeX ());
528
+ if (callOp.getClusterSizeY ())
529
+ newCall.getClusterSizeYMutable ().assign (callOp.getClusterSizeY ());
530
+ if (callOp.getClusterSizeZ ())
531
+ newCall.getClusterSizeZMutable ().assign (callOp.getClusterSizeZ ());
532
+ newCallResults.append (newCall.result_begin (), newCall.result_end ());
533
+ } else if constexpr (std::is_same_v<std::decay_t <A>, fir::CallOp>) {
511
534
fir::CallOp newCall;
512
535
if (callOp.getCallee ()) {
513
536
newCall =
0 commit comments