Commit 65e0031

[flang][cuda] Adapt TargetRewrite to support gpu.launch_func (#119933)
1 parent 1ab81f8 commit 65e0031
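
In short: TargetRewrite adjusts signatures and call sites to the target ABI, and this walk previously handled only fir.call, fir.dispatch, and function-address ops; the change teaches it to rewrite the kernel arguments of gpu.launch_func as well. A minimal sketch of the effect, distilled from the test added below (SSA value names are illustrative; the exact lowering depends on the target ABI):

// Before the pass: the launch passes a complex<f64> straight through.
gpu.launch_func @testmod::@_QPtest blocks in (%0, %0, %0) threads in (%0, %0, %0) : i64 dynamic_shared_memory_size %1 args(%arg0 : complex<f64>)

// After the pass, on a target whose ABI passes complex<f64> as a pair of
// f64 values: the launch and the kernel signature are rewritten to match.
gpu.launch_func @testmod::@_QPtest blocks in (%0, %0, %0) threads in (%0, %0, %0) : i64 dynamic_shared_memory_size %1 args(%2 : f64, %3 : f64)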

2 files changed: +55, -8 lines

flang/lib/Optimizer/CodeGen/TargetRewrite.cpp

Lines changed: 31 additions & 8 deletions
@@ -134,10 +134,18 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
     mod.walk([&](mlir::Operation *op) {
       if (auto call = mlir::dyn_cast<fir::CallOp>(op)) {
         if (!hasPortableSignature(call.getFunctionType(), op))
-          convertCallOp(call);
+          convertCallOp(call, call.getFunctionType());
       } else if (auto dispatch = mlir::dyn_cast<fir::DispatchOp>(op)) {
         if (!hasPortableSignature(dispatch.getFunctionType(), op))
-          convertCallOp(dispatch);
+          convertCallOp(dispatch, dispatch.getFunctionType());
+      } else if (auto gpuLaunchFunc =
+                     mlir::dyn_cast<mlir::gpu::LaunchFuncOp>(op)) {
+        llvm::SmallVector<mlir::Type> operandsTypes;
+        for (auto arg : gpuLaunchFunc.getKernelOperands())
+          operandsTypes.push_back(arg.getType());
+        auto fctTy = mlir::FunctionType::get(&context, operandsTypes, {});
+        if (!hasPortableSignature(fctTy, op))
+          convertCallOp(gpuLaunchFunc, fctTy);
       } else if (auto addr = mlir::dyn_cast<fir::AddrOfOp>(op)) {
         if (mlir::isa<mlir::FunctionType>(addr.getType()) &&
             !hasPortableSignature(addr.getType(), op))
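
Unlike fir.call and fir.dispatch, gpu.launch_func carries no callee function type of its own, so the hunk above synthesizes one from the kernel operand types (with an empty result list, since the launch itself returns nothing to rewrite). For the test case below, the synthesized type would be, roughly:

// Hypothetical illustration: a launch whose single kernel argument is a
// complex<f64> yields this synthetic type, which hasPortableSignature
// then inspects to decide whether the ABI rewrite is needed.
(complex<f64>) -> ()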
@@ -357,8 +365,7 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
 
   // Convert fir.call and fir.dispatch Ops.
   template <typename A>
-  void convertCallOp(A callOp) {
-    auto fnTy = callOp.getFunctionType();
+  void convertCallOp(A callOp, mlir::FunctionType fnTy) {
     auto loc = callOp.getLoc();
     rewriter->setInsertionPoint(callOp);
     llvm::SmallVector<mlir::Type> newResTys;
@@ -376,7 +383,7 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
         newOpers.push_back(callOp.getOperand(0));
         dropFront = 1;
       }
-    } else {
+    } else if constexpr (std::is_same_v<std::decay_t<A>, fir::DispatchOp>) {
       dropFront = 1; // First operand is the polymorphic object.
     }
 
@@ -402,10 +409,14 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
 
     llvm::SmallVector<mlir::Type> trailingInTys;
     llvm::SmallVector<mlir::Value> trailingOpers;
+    llvm::SmallVector<mlir::Value> operands;
     unsigned passArgShift = 0;
+    if constexpr (std::is_same_v<std::decay_t<A>, mlir::gpu::LaunchFuncOp>)
+      operands = callOp.getKernelOperands();
+    else
+      operands = callOp.getOperands().drop_front(dropFront);
     for (auto e : llvm::enumerate(
-             llvm::zip(fnTy.getInputs().drop_front(dropFront),
-                       callOp.getOperands().drop_front(dropFront)))) {
+             llvm::zip(fnTy.getInputs().drop_front(dropFront), operands))) {
       mlir::Type ty = std::get<0>(e.value());
       mlir::Value oper = std::get<1>(e.value());
       unsigned index = e.index();
@@ -507,7 +518,19 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
     newOpers.insert(newOpers.end(), trailingOpers.begin(), trailingOpers.end());
 
     llvm::SmallVector<mlir::Value, 1> newCallResults;
-    if constexpr (std::is_same_v<std::decay_t<A>, fir::CallOp>) {
+    if constexpr (std::is_same_v<std::decay_t<A>, mlir::gpu::LaunchFuncOp>) {
+      auto newCall = rewriter->create<A>(
+          loc, callOp.getKernel(), callOp.getGridSizeOperandValues(),
+          callOp.getBlockSizeOperandValues(),
+          callOp.getDynamicSharedMemorySize(), newOpers);
+      if (callOp.getClusterSizeX())
+        newCall.getClusterSizeXMutable().assign(callOp.getClusterSizeX());
+      if (callOp.getClusterSizeY())
+        newCall.getClusterSizeYMutable().assign(callOp.getClusterSizeY());
+      if (callOp.getClusterSizeZ())
+        newCall.getClusterSizeZMutable().assign(callOp.getClusterSizeZ());
+      newCallResults.append(newCall.result_begin(), newCall.result_end());
+    } else if constexpr (std::is_same_v<std::decay_t<A>, fir::CallOp>) {
       fir::CallOp newCall;
       if (callOp.getCallee()) {
         newCall =
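
Rather than mutating the launch in place, the hunk above creates a replacement gpu.launch_func from the converted operands, carrying over the kernel symbol, grid and block sizes, dynamic shared-memory size, and any optional cluster sizes. The kernel definition gets the matching portable signature, which the new test checks; reconstructed from its CHECK lines, the rewritten kernel looks like:

// Kernel after the rewrite: the complex<f64> parameter is split into two f64s.
gpu.func @_QPtest(%arg0: f64, %arg1: f64) kernel {
  gpu.return
}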

flang/test/Fir/CUDA/cuda-target-rewrite.mlir

Lines changed: 24 additions & 0 deletions
@@ -27,3 +27,27 @@ gpu.module @testmod {
 // CHECK-LABEL: gpu.func @_QPtest
 // CHECK-SAME: (%arg0: f64, %arg1: f64) -> tuple<f64, f64> {
 // CHECK: gpu.return %{{.*}} : tuple<f64, f64>
+
+
+// -----
+module attributes {gpu.container_module} {
+
+gpu.module @testmod {
+  gpu.func @_QPtest(%arg0: complex<f64>) -> () kernel {
+    gpu.return
+  }
+}
+
+func.func @main(%arg0: complex<f64>) {
+  %0 = llvm.mlir.constant(0 : i64) : i64
+  %1 = llvm.mlir.constant(0 : i32) : i32
+  gpu.launch_func @testmod::@_QPtest blocks in (%0, %0, %0) threads in (%0, %0, %0) : i64 dynamic_shared_memory_size %1 args(%arg0 : complex<f64>)
+  return
+}
+
+}
+
+// CHECK-LABEL: gpu.func @_QPtest
+// CHECK-SAME: (%arg0: f64, %arg1: f64) kernel {
+// CHECK: gpu.return
+// CHECK: gpu.launch_func @testmod::@_QPtest blocks in (%{{.*}}, %{{.*}}, %{{.*}}) threads in (%{{.*}}, %{{.*}}, %{{.*}}) : i64 dynamic_shared_memory_size %{{.*}} args(%{{.*}} : f64, %{{.*}} : f64)
