-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[flang][cuda] Adapt TargetRewrite to support gpu.launch_func #119933
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-flang-fir-hlfir Author: Valentin Clement (バレンタイン クレメン) (clementval) Changes: The gpu.func operations are already supported in the TargetRewrite pass. Update the pass to also support rewriting the gpu.launch_func operation. Full diff: https://github.com/llvm/llvm-project/pull/119933.diff 2 Files Affected:
diff --git a/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp b/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp
index 5a042b34a58c0a..b0b9499557e2b7 100644
--- a/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp
+++ b/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp
@@ -134,10 +134,18 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
mod.walk([&](mlir::Operation *op) {
if (auto call = mlir::dyn_cast<fir::CallOp>(op)) {
if (!hasPortableSignature(call.getFunctionType(), op))
- convertCallOp(call);
+ convertCallOp(call, call.getFunctionType());
} else if (auto dispatch = mlir::dyn_cast<fir::DispatchOp>(op)) {
if (!hasPortableSignature(dispatch.getFunctionType(), op))
- convertCallOp(dispatch);
+ convertCallOp(dispatch, dispatch.getFunctionType());
+ } else if (auto gpuLaunchFunc =
+ mlir::dyn_cast<mlir::gpu::LaunchFuncOp>(op)) {
+ llvm::SmallVector<mlir::Type> operandsTypes;
+ for (auto arg : gpuLaunchFunc.getKernelOperands())
+ operandsTypes.push_back(arg.getType());
+ auto fctTy = mlir::FunctionType::get(&context, operandsTypes, {});
+ if (!hasPortableSignature(fctTy, op))
+ convertCallOp(gpuLaunchFunc, fctTy);
} else if (auto addr = mlir::dyn_cast<fir::AddrOfOp>(op)) {
if (mlir::isa<mlir::FunctionType>(addr.getType()) &&
!hasPortableSignature(addr.getType(), op))
@@ -357,8 +365,7 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
// Convert fir.call and fir.dispatch Ops.
template <typename A>
- void convertCallOp(A callOp) {
- auto fnTy = callOp.getFunctionType();
+ void convertCallOp(A callOp, mlir::FunctionType fnTy) {
auto loc = callOp.getLoc();
rewriter->setInsertionPoint(callOp);
llvm::SmallVector<mlir::Type> newResTys;
@@ -376,7 +383,7 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
newOpers.push_back(callOp.getOperand(0));
dropFront = 1;
}
- } else {
+ } else if constexpr (std::is_same_v<std::decay_t<A>, fir::DispatchOp>) {
dropFront = 1; // First operand is the polymorphic object.
}
@@ -402,10 +409,14 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
llvm::SmallVector<mlir::Type> trailingInTys;
llvm::SmallVector<mlir::Value> trailingOpers;
+ llvm::SmallVector<mlir::Value> operands;
unsigned passArgShift = 0;
+ if constexpr (std::is_same_v<std::decay_t<A>, mlir::gpu::LaunchFuncOp>)
+ operands = callOp.getKernelOperands();
+ else
+ operands = callOp.getOperands().drop_front(dropFront);
for (auto e : llvm::enumerate(
- llvm::zip(fnTy.getInputs().drop_front(dropFront),
- callOp.getOperands().drop_front(dropFront)))) {
+ llvm::zip(fnTy.getInputs().drop_front(dropFront), operands))) {
mlir::Type ty = std::get<0>(e.value());
mlir::Value oper = std::get<1>(e.value());
unsigned index = e.index();
@@ -507,7 +518,19 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
newOpers.insert(newOpers.end(), trailingOpers.begin(), trailingOpers.end());
llvm::SmallVector<mlir::Value, 1> newCallResults;
- if constexpr (std::is_same_v<std::decay_t<A>, fir::CallOp>) {
+ if constexpr (std::is_same_v<std::decay_t<A>, mlir::gpu::LaunchFuncOp>) {
+ auto newCall = rewriter->create<A>(
+ loc, callOp.getKernel(), callOp.getGridSizeOperandValues(),
+ callOp.getBlockSizeOperandValues(),
+ callOp.getDynamicSharedMemorySize(), newOpers);
+ if (callOp.getClusterSizeX())
+ newCall.getClusterSizeXMutable().assign(callOp.getClusterSizeX());
+ if (callOp.getClusterSizeY())
+ newCall.getClusterSizeYMutable().assign(callOp.getClusterSizeY());
+ if (callOp.getClusterSizeZ())
+ newCall.getClusterSizeZMutable().assign(callOp.getClusterSizeZ());
+ newCallResults.append(newCall.result_begin(), newCall.result_end());
+ } else if constexpr (std::is_same_v<std::decay_t<A>, fir::CallOp>) {
fir::CallOp newCall;
if (callOp.getCallee()) {
newCall =
diff --git a/flang/test/Fir/CUDA/cuda-target-rewrite.mlir b/flang/test/Fir/CUDA/cuda-target-rewrite.mlir
index d88b6776795a0b..0e7534e06c89c9 100644
--- a/flang/test/Fir/CUDA/cuda-target-rewrite.mlir
+++ b/flang/test/Fir/CUDA/cuda-target-rewrite.mlir
@@ -27,3 +27,27 @@ gpu.module @testmod {
// CHECK-LABEL: gpu.func @_QPtest
// CHECK-SAME: (%arg0: f64, %arg1: f64) -> tuple<f64, f64> {
// CHECK: gpu.return %{{.*}} : tuple<f64, f64>
+
+
+// -----
+module attributes {gpu.container_module} {
+
+gpu.module @testmod {
+ gpu.func @_QPtest(%arg0: complex<f64>) -> () kernel {
+ gpu.return
+ }
+}
+
+func.func @main(%arg0: complex<f64>) {
+ %0 = llvm.mlir.constant(0 : i64) : i64
+ %1 = llvm.mlir.constant(0 : i32) : i32
+ gpu.launch_func @testmod::@_QPtest blocks in (%0, %0, %0) threads in (%0, %0, %0) : i64 dynamic_shared_memory_size %1 args(%arg0 : complex<f64>)
+ return
+}
+
+}
+
+// CHECK-LABEL: gpu.func @_QPtest
+// CHECK-SAME: (%arg0: f64, %arg1: f64) kernel {
+// CHECK: gpu.return
+// CHECK: gpu.launch_func @testmod::@_QPtest blocks in (%{{.*}}, %{{.*}}, %{{.*}}) threads in (%{{.*}}, %{{.*}}, %{{.*}}) : i64 dynamic_shared_memory_size %{{.*}} args(%{{.*}} : f64, %{{.*}} : f64)
|
@llvm/pr-subscribers-flang-codegen Author: Valentin Clement (バレンタイン クレメン) (clementval) Changes: The gpu.func operations are already supported in the TargetRewrite pass. Update the pass to also support rewriting the gpu.launch_func operation. Full diff: https://github.com/llvm/llvm-project/pull/119933.diff 2 Files Affected:
diff --git a/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp b/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp
index 5a042b34a58c0a..b0b9499557e2b7 100644
--- a/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp
+++ b/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp
@@ -134,10 +134,18 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
mod.walk([&](mlir::Operation *op) {
if (auto call = mlir::dyn_cast<fir::CallOp>(op)) {
if (!hasPortableSignature(call.getFunctionType(), op))
- convertCallOp(call);
+ convertCallOp(call, call.getFunctionType());
} else if (auto dispatch = mlir::dyn_cast<fir::DispatchOp>(op)) {
if (!hasPortableSignature(dispatch.getFunctionType(), op))
- convertCallOp(dispatch);
+ convertCallOp(dispatch, dispatch.getFunctionType());
+ } else if (auto gpuLaunchFunc =
+ mlir::dyn_cast<mlir::gpu::LaunchFuncOp>(op)) {
+ llvm::SmallVector<mlir::Type> operandsTypes;
+ for (auto arg : gpuLaunchFunc.getKernelOperands())
+ operandsTypes.push_back(arg.getType());
+ auto fctTy = mlir::FunctionType::get(&context, operandsTypes, {});
+ if (!hasPortableSignature(fctTy, op))
+ convertCallOp(gpuLaunchFunc, fctTy);
} else if (auto addr = mlir::dyn_cast<fir::AddrOfOp>(op)) {
if (mlir::isa<mlir::FunctionType>(addr.getType()) &&
!hasPortableSignature(addr.getType(), op))
@@ -357,8 +365,7 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
// Convert fir.call and fir.dispatch Ops.
template <typename A>
- void convertCallOp(A callOp) {
- auto fnTy = callOp.getFunctionType();
+ void convertCallOp(A callOp, mlir::FunctionType fnTy) {
auto loc = callOp.getLoc();
rewriter->setInsertionPoint(callOp);
llvm::SmallVector<mlir::Type> newResTys;
@@ -376,7 +383,7 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
newOpers.push_back(callOp.getOperand(0));
dropFront = 1;
}
- } else {
+ } else if constexpr (std::is_same_v<std::decay_t<A>, fir::DispatchOp>) {
dropFront = 1; // First operand is the polymorphic object.
}
@@ -402,10 +409,14 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
llvm::SmallVector<mlir::Type> trailingInTys;
llvm::SmallVector<mlir::Value> trailingOpers;
+ llvm::SmallVector<mlir::Value> operands;
unsigned passArgShift = 0;
+ if constexpr (std::is_same_v<std::decay_t<A>, mlir::gpu::LaunchFuncOp>)
+ operands = callOp.getKernelOperands();
+ else
+ operands = callOp.getOperands().drop_front(dropFront);
for (auto e : llvm::enumerate(
- llvm::zip(fnTy.getInputs().drop_front(dropFront),
- callOp.getOperands().drop_front(dropFront)))) {
+ llvm::zip(fnTy.getInputs().drop_front(dropFront), operands))) {
mlir::Type ty = std::get<0>(e.value());
mlir::Value oper = std::get<1>(e.value());
unsigned index = e.index();
@@ -507,7 +518,19 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
newOpers.insert(newOpers.end(), trailingOpers.begin(), trailingOpers.end());
llvm::SmallVector<mlir::Value, 1> newCallResults;
- if constexpr (std::is_same_v<std::decay_t<A>, fir::CallOp>) {
+ if constexpr (std::is_same_v<std::decay_t<A>, mlir::gpu::LaunchFuncOp>) {
+ auto newCall = rewriter->create<A>(
+ loc, callOp.getKernel(), callOp.getGridSizeOperandValues(),
+ callOp.getBlockSizeOperandValues(),
+ callOp.getDynamicSharedMemorySize(), newOpers);
+ if (callOp.getClusterSizeX())
+ newCall.getClusterSizeXMutable().assign(callOp.getClusterSizeX());
+ if (callOp.getClusterSizeY())
+ newCall.getClusterSizeYMutable().assign(callOp.getClusterSizeY());
+ if (callOp.getClusterSizeZ())
+ newCall.getClusterSizeZMutable().assign(callOp.getClusterSizeZ());
+ newCallResults.append(newCall.result_begin(), newCall.result_end());
+ } else if constexpr (std::is_same_v<std::decay_t<A>, fir::CallOp>) {
fir::CallOp newCall;
if (callOp.getCallee()) {
newCall =
diff --git a/flang/test/Fir/CUDA/cuda-target-rewrite.mlir b/flang/test/Fir/CUDA/cuda-target-rewrite.mlir
index d88b6776795a0b..0e7534e06c89c9 100644
--- a/flang/test/Fir/CUDA/cuda-target-rewrite.mlir
+++ b/flang/test/Fir/CUDA/cuda-target-rewrite.mlir
@@ -27,3 +27,27 @@ gpu.module @testmod {
// CHECK-LABEL: gpu.func @_QPtest
// CHECK-SAME: (%arg0: f64, %arg1: f64) -> tuple<f64, f64> {
// CHECK: gpu.return %{{.*}} : tuple<f64, f64>
+
+
+// -----
+module attributes {gpu.container_module} {
+
+gpu.module @testmod {
+ gpu.func @_QPtest(%arg0: complex<f64>) -> () kernel {
+ gpu.return
+ }
+}
+
+func.func @main(%arg0: complex<f64>) {
+ %0 = llvm.mlir.constant(0 : i64) : i64
+ %1 = llvm.mlir.constant(0 : i32) : i32
+ gpu.launch_func @testmod::@_QPtest blocks in (%0, %0, %0) threads in (%0, %0, %0) : i64 dynamic_shared_memory_size %1 args(%arg0 : complex<f64>)
+ return
+}
+
+}
+
+// CHECK-LABEL: gpu.func @_QPtest
+// CHECK-SAME: (%arg0: f64, %arg1: f64) kernel {
+// CHECK: gpu.return
+// CHECK: gpu.launch_func @testmod::@_QPtest blocks in (%{{.*}}, %{{.*}}, %{{.*}}) threads in (%{{.*}}, %{{.*}}, %{{.*}}) : i64 dynamic_shared_memory_size %{{.*}} args(%{{.*}} : f64, %{{.*}} : f64)
|
The gpu.func operations are already supported in the TargetRewrite pass. Update the pass to also support rewriting the gpu.launch_func operation.