Skip to content

[flang][cuda] Update target rewrite to work on gpu.func #119283

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Dec 10, 2024

Conversation

clementval
Copy link
Contributor

Update the pass so it can perform the signature rewrite on gpu.func.

@llvmbot llvmbot added flang Flang issues not falling into any other category flang:fir-hlfir flang:codegen labels Dec 9, 2024
@llvmbot
Copy link
Member

llvmbot commented Dec 9, 2024

@llvm/pr-subscribers-flang-codegen

Author: Valentin Clement (バレンタイン クレメン) (clementval)

Changes

Update the pass so it can perform the signature rewrite on gpu.func.


Full diff: https://github.com/llvm/llvm-project/pull/119283.diff

2 Files Affected:

  • (modified) flang/lib/Optimizer/CodeGen/TargetRewrite.cpp (+99-64)
  • (modified) flang/test/Fir/CUDA/cuda-target-rewrite.mlir (+13-1)
diff --git a/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp b/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp
index 1b86d5241704b1..4603a7d6cb4ec5 100644
--- a/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp
+++ b/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp
@@ -62,14 +62,21 @@ struct FixupTy {
   FixupTy(Codes code, std::size_t index,
           std::function<void(mlir::func::FuncOp)> &&finalizer)
       : code{code}, index{index}, finalizer{finalizer} {}
+  FixupTy(Codes code, std::size_t index,
+          std::function<void(mlir::gpu::GPUFuncOp)> &&finalizer)
+      : code{code}, index{index}, gpuFinalizer{finalizer} {}
   FixupTy(Codes code, std::size_t index, std::size_t second,
           std::function<void(mlir::func::FuncOp)> &&finalizer)
       : code{code}, index{index}, second{second}, finalizer{finalizer} {}
+  FixupTy(Codes code, std::size_t index, std::size_t second,
+          std::function<void(mlir::gpu::GPUFuncOp)> &&finalizer)
+      : code{code}, index{index}, second{second}, gpuFinalizer{finalizer} {}
 
   Codes code;
   std::size_t index;
   std::size_t second{};
   std::optional<std::function<void(mlir::func::FuncOp)>> finalizer{};
+  std::optional<std::function<void(mlir::gpu::GPUFuncOp)>> gpuFinalizer{};
 }; // namespace
 
 /// Target-specific rewriting of the FIR. This is a prerequisite pass to code
@@ -722,9 +729,12 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
       convertSignature(fn);
     }
 
-    for (auto gpuMod : mod.getOps<mlir::gpu::GPUModuleOp>())
+    for (auto gpuMod : mod.getOps<mlir::gpu::GPUModuleOp>()) {
       for (auto fn : gpuMod.getOps<mlir::func::FuncOp>())
         convertSignature(fn);
+      for (auto fn : gpuMod.getOps<mlir::gpu::GPUFuncOp>())
+        convertSignature(fn);
+    }
 
     return mlir::success();
   }
@@ -770,17 +780,20 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
 
   /// Determine if the signature has host associations. The host association
   /// argument may need special target specific rewriting.
-  static bool hasHostAssociations(mlir::func::FuncOp func) {
+  template <typename OpTy>
+  static bool hasHostAssociations(OpTy func) {
     std::size_t end = func.getFunctionType().getInputs().size();
     for (std::size_t i = 0; i < end; ++i)
-      if (func.getArgAttrOfType<mlir::UnitAttr>(i, fir::getHostAssocAttrName()))
+      if (func.template getArgAttrOfType<mlir::UnitAttr>(
+              i, fir::getHostAssocAttrName()))
         return true;
     return false;
   }
 
   /// Rewrite the signatures and body of the `FuncOp`s in the module for
   /// the immediately subsequent target code gen.
-  void convertSignature(mlir::func::FuncOp func) {
+  template <typename OpTy>
+  void convertSignature(OpTy func) {
     auto funcTy = mlir::cast<mlir::FunctionType>(func.getFunctionType());
     if (hasPortableSignature(funcTy, func) && !hasHostAssociations(func))
       return;
@@ -805,13 +818,13 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
     // Convert return value(s)
     for (auto ty : funcTy.getResults())
       llvm::TypeSwitch<mlir::Type>(ty)
-          .Case<mlir::ComplexType>([&](mlir::ComplexType cmplx) {
+          .template Case<mlir::ComplexType>([&](mlir::ComplexType cmplx) {
             if (noComplexConversion)
               newResTys.push_back(cmplx);
             else
               doComplexReturn(func, cmplx, newResTys, newInTyAndAttrs, fixups);
           })
-          .Case<mlir::IntegerType>([&](mlir::IntegerType intTy) {
+          .template Case<mlir::IntegerType>([&](mlir::IntegerType intTy) {
             auto m = specifics->integerArgumentType(func.getLoc(), intTy);
             assert(m.size() == 1);
             auto attr = std::get<fir::CodeGenSpecifics::Attributes>(m[0]);
@@ -825,10 +838,10 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
                                                 rewriter->getUnitAttr()));
             newResTys.push_back(retTy);
           })
-          .Case<fir::RecordType>([&](fir::RecordType recTy) {
+          .template Case<fir::RecordType>([&](fir::RecordType recTy) {
             doStructReturn(func, recTy, newResTys, newInTyAndAttrs, fixups);
           })
-          .Default([&](mlir::Type ty) { newResTys.push_back(ty); });
+          .template Default([&](mlir::Type ty) { newResTys.push_back(ty); });
 
     // Saved potential shift in argument. Handling of result can add arguments
     // at the beginning of the function signature.
@@ -840,7 +853,7 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
       auto ty = e.value();
       unsigned index = e.index();
       llvm::TypeSwitch<mlir::Type>(ty)
-          .Case<fir::BoxCharType>([&](fir::BoxCharType boxTy) {
+          .template Case<fir::BoxCharType>([&](fir::BoxCharType boxTy) {
             if (noCharacterConversion) {
               newInTyAndAttrs.push_back(
                   fir::CodeGenSpecifics::getTypeAndAttr(boxTy));
@@ -863,10 +876,10 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
               }
             }
           })
-          .Case<mlir::ComplexType>([&](mlir::ComplexType cmplx) {
+          .template Case<mlir::ComplexType>([&](mlir::ComplexType cmplx) {
             doComplexArg(func, cmplx, newInTyAndAttrs, fixups);
           })
-          .Case<mlir::TupleType>([&](mlir::TupleType tuple) {
+          .template Case<mlir::TupleType>([&](mlir::TupleType tuple) {
             if (fir::isCharacterProcedureTuple(tuple)) {
               fixups.emplace_back(FixupTy::Codes::TrailingCharProc,
                                   newInTyAndAttrs.size(), trailingTys.size());
@@ -878,7 +891,7 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
                   fir::CodeGenSpecifics::getTypeAndAttr(ty));
             }
           })
-          .Case<mlir::IntegerType>([&](mlir::IntegerType intTy) {
+          .template Case<mlir::IntegerType>([&](mlir::IntegerType intTy) {
             auto m = specifics->integerArgumentType(func.getLoc(), intTy);
             assert(m.size() == 1);
             auto attr = std::get<fir::CodeGenSpecifics::Attributes>(m[0]);
@@ -887,7 +900,7 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
             if (!extensionAttrName.empty() &&
                 isFuncWithCCallingConvention(func))
               fixups.emplace_back(FixupTy::Codes::ArgumentType, argNo,
-                                  [=](mlir::func::FuncOp func) {
+                                  [=](OpTy func) {
                                     func.setArgAttr(
                                         argNo, extensionAttrName,
                                         mlir::UnitAttr::get(func.getContext()));
@@ -898,13 +911,13 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
           .template Case<fir::RecordType>([&](fir::RecordType recTy) {
             doStructArg(func, recTy, newInTyAndAttrs, fixups);
           })
-          .Default([&](mlir::Type ty) {
+          .template Default([&](mlir::Type ty) {
             newInTyAndAttrs.push_back(
                 fir::CodeGenSpecifics::getTypeAndAttr(ty));
           });
 
-      if (func.getArgAttrOfType<mlir::UnitAttr>(index,
-                                                fir::getHostAssocAttrName())) {
+      if (func.template getArgAttrOfType<mlir::UnitAttr>(
+              index, fir::getHostAssocAttrName())) {
         extraAttrs.push_back(
             {newInTyAndAttrs.size() - 1,
              rewriter->getNamedAttr("llvm.nest", rewriter->getUnitAttr())});
@@ -979,29 +992,52 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
           auto newArg =
               func.front().insertArgument(fixup.index, fixupType, loc);
           offset++;
-          func.walk([&](mlir::func::ReturnOp ret) {
-            rewriter->setInsertionPoint(ret);
-            auto oldOper = ret.getOperand(0);
-            auto oldOperTy = fir::ReferenceType::get(oldOper.getType());
-            auto cast =
-                rewriter->create<fir::ConvertOp>(loc, oldOperTy, newArg);
-            rewriter->create<fir::StoreOp>(loc, oldOper, cast);
-            rewriter->create<mlir::func::ReturnOp>(loc);
-            ret.erase();
-          });
+          if constexpr (std::is_same_v<OpTy, mlir::func::FuncOp>)
+            func.walk([&](mlir::func::ReturnOp ret) {
+              rewriter->setInsertionPoint(ret);
+              auto oldOper = ret.getOperand(0);
+              auto oldOperTy = fir::ReferenceType::get(oldOper.getType());
+              auto cast =
+                  rewriter->create<fir::ConvertOp>(loc, oldOperTy, newArg);
+              rewriter->create<fir::StoreOp>(loc, oldOper, cast);
+              rewriter->create<mlir::func::ReturnOp>(loc);
+              ret.erase();
+            });
+          if constexpr (std::is_same_v<OpTy, mlir::gpu::GPUFuncOp>)
+            func.walk([&](mlir::gpu::ReturnOp ret) {
+              rewriter->setInsertionPoint(ret);
+              auto oldOper = ret.getOperand(0);
+              auto oldOperTy = fir::ReferenceType::get(oldOper.getType());
+              auto cast =
+                  rewriter->create<fir::ConvertOp>(loc, oldOperTy, newArg);
+              rewriter->create<fir::StoreOp>(loc, oldOper, cast);
+              rewriter->create<mlir::gpu::ReturnOp>(loc);
+              ret.erase();
+            });
         } break;
         case FixupTy::Codes::ReturnType: {
           // The function is still returning a value, but its type has likely
           // changed to suit the target ABI convention.
-          func.walk([&](mlir::func::ReturnOp ret) {
-            rewriter->setInsertionPoint(ret);
-            auto oldOper = ret.getOperand(0);
-            mlir::Value bitcast =
-                convertValueInMemory(loc, oldOper, newResTys[fixup.index],
-                                     /*inputMayBeBigger=*/false);
-            rewriter->create<mlir::func::ReturnOp>(loc, bitcast);
-            ret.erase();
-          });
+          if constexpr (std::is_same_v<OpTy, mlir::func::FuncOp>)
+            func.walk([&](mlir::func::ReturnOp ret) {
+              rewriter->setInsertionPoint(ret);
+              auto oldOper = ret.getOperand(0);
+              mlir::Value bitcast =
+                  convertValueInMemory(loc, oldOper, newResTys[fixup.index],
+                                       /*inputMayBeBigger=*/false);
+              rewriter->create<mlir::func::ReturnOp>(loc, bitcast);
+              ret.erase();
+            });
+          if constexpr (std::is_same_v<OpTy, mlir::gpu::GPUFuncOp>)
+            func.walk([&](mlir::gpu::ReturnOp ret) {
+              rewriter->setInsertionPoint(ret);
+              auto oldOper = ret.getOperand(0);
+              mlir::Value bitcast =
+                  convertValueInMemory(loc, oldOper, newResTys[fixup.index],
+                                       /*inputMayBeBigger=*/false);
+              rewriter->create<mlir::gpu::ReturnOp>(loc, bitcast);
+              ret.erase();
+            });
         } break;
         case FixupTy::Codes::Split: {
           // The FIR argument has been split into a pair of distinct arguments
@@ -1101,13 +1137,18 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
       }
     }
 
-    for (auto &fixup : fixups)
-      if (fixup.finalizer)
-        (*fixup.finalizer)(func);
+    for (auto &fixup : fixups) {
+      if constexpr (std::is_same_v<OpTy, mlir::func::FuncOp>)
+        if (fixup.finalizer)
+          (*fixup.finalizer)(func);
+      if constexpr (std::is_same_v<OpTy, mlir::gpu::GPUFuncOp>)
+        if (fixup.gpuFinalizer)
+          (*fixup.gpuFinalizer)(func);
+    }
   }
 
-  template <typename Ty, typename FIXUPS>
-  void doReturn(mlir::func::FuncOp func, Ty &newResTys,
+  template <typename OpTy, typename Ty, typename FIXUPS>
+  void doReturn(OpTy func, Ty &newResTys,
                 fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs,
                 FIXUPS &fixups, fir::CodeGenSpecifics::Marshalling &m) {
     assert(m.size() == 1 &&
@@ -1119,7 +1160,7 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
       unsigned argNo = newInTyAndAttrs.size();
       if (auto align = attr.getAlignment())
         fixups.emplace_back(
-            FixupTy::Codes::ReturnAsStore, argNo, [=](mlir::func::FuncOp func) {
+            FixupTy::Codes::ReturnAsStore, argNo, [=](OpTy func) {
               auto elemType = fir::dyn_cast_ptrOrBoxEleTy(
                   func.getFunctionType().getInput(argNo));
               func.setArgAttr(argNo, "llvm.sret",
@@ -1130,7 +1171,7 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
             });
       else
         fixups.emplace_back(FixupTy::Codes::ReturnAsStore, argNo,
-                            [=](mlir::func::FuncOp func) {
+                            [=](OpTy func) {
                               auto elemType = fir::dyn_cast_ptrOrBoxEleTy(
                                   func.getFunctionType().getInput(argNo));
                               func.setArgAttr(argNo, "llvm.sret",
@@ -1141,8 +1182,7 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
     }
     if (auto align = attr.getAlignment())
       fixups.emplace_back(
-          FixupTy::Codes::ReturnType, newResTys.size(),
-          [=](mlir::func::FuncOp func) {
+          FixupTy::Codes::ReturnType, newResTys.size(), [=](OpTy func) {
             func.setArgAttr(
                 newResTys.size(), "llvm.align",
                 rewriter->getIntegerAttr(rewriter->getIntegerType(32), align));
@@ -1155,9 +1195,8 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
   /// Convert a complex return value. This can involve converting the return
   /// value to a "hidden" first argument or packing the complex into a wide
   /// GPR.
-  template <typename Ty, typename FIXUPS>
-  void doComplexReturn(mlir::func::FuncOp func, mlir::ComplexType cmplx,
-                       Ty &newResTys,
+  template <typename OpTy, typename Ty, typename FIXUPS>
+  void doComplexReturn(OpTy func, mlir::ComplexType cmplx, Ty &newResTys,
                        fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs,
                        FIXUPS &fixups) {
     if (noComplexConversion) {
@@ -1169,9 +1208,8 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
     doReturn(func, newResTys, newInTyAndAttrs, fixups, m);
   }
 
-  template <typename Ty, typename FIXUPS>
-  void doStructReturn(mlir::func::FuncOp func, fir::RecordType recTy,
-                      Ty &newResTys,
+  template <typename OpTy, typename Ty, typename FIXUPS>
+  void doStructReturn(OpTy func, fir::RecordType recTy, Ty &newResTys,
                       fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs,
                       FIXUPS &fixups) {
     if (noStructConversion) {
@@ -1182,12 +1220,10 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
     doReturn(func, newResTys, newInTyAndAttrs, fixups, m);
   }
 
-  template <typename FIXUPS>
-  void
-  createFuncOpArgFixups(mlir::func::FuncOp func,
-                        fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs,
-                        fir::CodeGenSpecifics::Marshalling &argsInTys,
-                        FIXUPS &fixups) {
+  template <typename OpTy, typename FIXUPS>
+  void createFuncOpArgFixups(
+      OpTy func, fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs,
+      fir::CodeGenSpecifics::Marshalling &argsInTys, FIXUPS &fixups) {
     const auto fixupCode = argsInTys.size() > 1 ? FixupTy::Codes::Split
                                                 : FixupTy::Codes::ArgumentType;
     for (auto e : llvm::enumerate(argsInTys)) {
@@ -1198,7 +1234,7 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
       if (attr.isByVal()) {
         if (auto align = attr.getAlignment())
           fixups.emplace_back(FixupTy::Codes::ArgumentAsLoad, argNo,
-                              [=](mlir::func::FuncOp func) {
+                              [=](OpTy func) {
                                 auto elemType = fir::dyn_cast_ptrOrBoxEleTy(
                                     func.getFunctionType().getInput(argNo));
                                 func.setArgAttr(argNo, "llvm.byval",
@@ -1210,8 +1246,7 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
                               });
         else
           fixups.emplace_back(FixupTy::Codes::ArgumentAsLoad,
-                              newInTyAndAttrs.size(),
-                              [=](mlir::func::FuncOp func) {
+                              newInTyAndAttrs.size(), [=](OpTy func) {
                                 auto elemType = fir::dyn_cast_ptrOrBoxEleTy(
                                     func.getFunctionType().getInput(argNo));
                                 func.setArgAttr(argNo, "llvm.byval",
@@ -1220,7 +1255,7 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
       } else {
         if (auto align = attr.getAlignment())
           fixups.emplace_back(
-              fixupCode, argNo, index, [=](mlir::func::FuncOp func) {
+              fixupCode, argNo, index, [=](OpTy func) {
                 func.setArgAttr(argNo, "llvm.align",
                                 rewriter->getIntegerAttr(
                                     rewriter->getIntegerType(32), align));
@@ -1235,8 +1270,8 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
   /// Convert a complex argument value. This can involve storing the value to
   /// a temporary memory location or factoring the value into two distinct
   /// arguments.
-  template <typename FIXUPS>
-  void doComplexArg(mlir::func::FuncOp func, mlir::ComplexType cmplx,
+  template <typename OpTy, typename FIXUPS>
+  void doComplexArg(OpTy func, mlir::ComplexType cmplx,
                     fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs,
                     FIXUPS &fixups) {
     if (noComplexConversion) {
@@ -1248,8 +1283,8 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
     createFuncOpArgFixups(func, newInTyAndAttrs, cplxArgs, fixups);
   }
 
-  template <typename FIXUPS>
-  void doStructArg(mlir::func::FuncOp func, fir::RecordType recTy,
+  template <typename OpTy, typename FIXUPS>
+  void doStructArg(OpTy func, fir::RecordType recTy,
                    fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs,
                    FIXUPS &fixups) {
     if (noStructConversion) {
diff --git a/flang/test/Fir/CUDA/cuda-target-rewrite.mlir b/flang/test/Fir/CUDA/cuda-target-rewrite.mlir
index c14efc8b13f66b..d88b6776795a0b 100644
--- a/flang/test/Fir/CUDA/cuda-target-rewrite.mlir
+++ b/flang/test/Fir/CUDA/cuda-target-rewrite.mlir
@@ -1,5 +1,5 @@
 // REQUIRES: x86-registered-target
-// RUN: fir-opt --target-rewrite="target=x86_64-unknown-linux-gnu" %s | FileCheck %s
+// RUN: fir-opt --split-input-file --target-rewrite="target=x86_64-unknown-linux-gnu" %s | FileCheck %s
 
 gpu.module @testmod {
   gpu.func @_QPvcpowdk(%arg0: !fir.ref<complex<f64>> {cuf.data_attr = #cuf.cuda<device>, fir.bindc_name = "a"}) attributes {cuf.proc_attr = #cuf.cuda_proc<global>} {
@@ -15,3 +15,15 @@ gpu.module @testmod {
 // CHECK-LABEL: gpu.func @_QPvcpowdk
 // CHECK: %{{.*}} = fir.call @_FortranAzpowk(%{{.*}}, %{{.*}}, %{{.*}}) : (f64, f64, i64) -> tuple<f64, f64>
 // CHECK: func.func private @_FortranAzpowk(f64, f64, i64) -> tuple<f64, f64> attributes {fir.bindc_name = "_FortranAzpowk", fir.runtime}
+
+// -----
+
+gpu.module @testmod {
+  gpu.func @_QPtest(%arg0: complex<f64>) -> (complex<f64>) {
+    gpu.return %arg0 : complex<f64>
+  }
+}
+
+// CHECK-LABEL: gpu.func @_QPtest
+// CHECK-SAME: (%arg0: f64, %arg1: f64) -> tuple<f64, f64> {
+// CHECK: gpu.return %{{.*}} : tuple<f64, f64>

@llvmbot
Copy link
Member

llvmbot commented Dec 9, 2024

@llvm/pr-subscribers-flang-fir-hlfir

Author: Valentin Clement (バレンタイン クレメン) (clementval)

Changes

Update the pass so it can perform the signature rewrite on gpu.func.


Full diff: https://github.com/llvm/llvm-project/pull/119283.diff

2 Files Affected:

  • (modified) flang/lib/Optimizer/CodeGen/TargetRewrite.cpp (+99-64)
  • (modified) flang/test/Fir/CUDA/cuda-target-rewrite.mlir (+13-1)
diff --git a/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp b/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp
index 1b86d5241704b1..4603a7d6cb4ec5 100644
--- a/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp
+++ b/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp
@@ -62,14 +62,21 @@ struct FixupTy {
   FixupTy(Codes code, std::size_t index,
           std::function<void(mlir::func::FuncOp)> &&finalizer)
       : code{code}, index{index}, finalizer{finalizer} {}
+  FixupTy(Codes code, std::size_t index,
+          std::function<void(mlir::gpu::GPUFuncOp)> &&finalizer)
+      : code{code}, index{index}, gpuFinalizer{finalizer} {}
   FixupTy(Codes code, std::size_t index, std::size_t second,
           std::function<void(mlir::func::FuncOp)> &&finalizer)
       : code{code}, index{index}, second{second}, finalizer{finalizer} {}
+  FixupTy(Codes code, std::size_t index, std::size_t second,
+          std::function<void(mlir::gpu::GPUFuncOp)> &&finalizer)
+      : code{code}, index{index}, second{second}, gpuFinalizer{finalizer} {}
 
   Codes code;
   std::size_t index;
   std::size_t second{};
   std::optional<std::function<void(mlir::func::FuncOp)>> finalizer{};
+  std::optional<std::function<void(mlir::gpu::GPUFuncOp)>> gpuFinalizer{};
 }; // namespace
 
 /// Target-specific rewriting of the FIR. This is a prerequisite pass to code
@@ -722,9 +729,12 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
       convertSignature(fn);
     }
 
-    for (auto gpuMod : mod.getOps<mlir::gpu::GPUModuleOp>())
+    for (auto gpuMod : mod.getOps<mlir::gpu::GPUModuleOp>()) {
       for (auto fn : gpuMod.getOps<mlir::func::FuncOp>())
         convertSignature(fn);
+      for (auto fn : gpuMod.getOps<mlir::gpu::GPUFuncOp>())
+        convertSignature(fn);
+    }
 
     return mlir::success();
   }
@@ -770,17 +780,20 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
 
   /// Determine if the signature has host associations. The host association
   /// argument may need special target specific rewriting.
-  static bool hasHostAssociations(mlir::func::FuncOp func) {
+  template <typename OpTy>
+  static bool hasHostAssociations(OpTy func) {
     std::size_t end = func.getFunctionType().getInputs().size();
     for (std::size_t i = 0; i < end; ++i)
-      if (func.getArgAttrOfType<mlir::UnitAttr>(i, fir::getHostAssocAttrName()))
+      if (func.template getArgAttrOfType<mlir::UnitAttr>(
+              i, fir::getHostAssocAttrName()))
         return true;
     return false;
   }
 
   /// Rewrite the signatures and body of the `FuncOp`s in the module for
   /// the immediately subsequent target code gen.
-  void convertSignature(mlir::func::FuncOp func) {
+  template <typename OpTy>
+  void convertSignature(OpTy func) {
     auto funcTy = mlir::cast<mlir::FunctionType>(func.getFunctionType());
     if (hasPortableSignature(funcTy, func) && !hasHostAssociations(func))
       return;
@@ -805,13 +818,13 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
     // Convert return value(s)
     for (auto ty : funcTy.getResults())
       llvm::TypeSwitch<mlir::Type>(ty)
-          .Case<mlir::ComplexType>([&](mlir::ComplexType cmplx) {
+          .template Case<mlir::ComplexType>([&](mlir::ComplexType cmplx) {
             if (noComplexConversion)
               newResTys.push_back(cmplx);
             else
               doComplexReturn(func, cmplx, newResTys, newInTyAndAttrs, fixups);
           })
-          .Case<mlir::IntegerType>([&](mlir::IntegerType intTy) {
+          .template Case<mlir::IntegerType>([&](mlir::IntegerType intTy) {
             auto m = specifics->integerArgumentType(func.getLoc(), intTy);
             assert(m.size() == 1);
             auto attr = std::get<fir::CodeGenSpecifics::Attributes>(m[0]);
@@ -825,10 +838,10 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
                                                 rewriter->getUnitAttr()));
             newResTys.push_back(retTy);
           })
-          .Case<fir::RecordType>([&](fir::RecordType recTy) {
+          .template Case<fir::RecordType>([&](fir::RecordType recTy) {
             doStructReturn(func, recTy, newResTys, newInTyAndAttrs, fixups);
           })
-          .Default([&](mlir::Type ty) { newResTys.push_back(ty); });
+          .template Default([&](mlir::Type ty) { newResTys.push_back(ty); });
 
     // Saved potential shift in argument. Handling of result can add arguments
     // at the beginning of the function signature.
@@ -840,7 +853,7 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
       auto ty = e.value();
       unsigned index = e.index();
       llvm::TypeSwitch<mlir::Type>(ty)
-          .Case<fir::BoxCharType>([&](fir::BoxCharType boxTy) {
+          .template Case<fir::BoxCharType>([&](fir::BoxCharType boxTy) {
             if (noCharacterConversion) {
               newInTyAndAttrs.push_back(
                   fir::CodeGenSpecifics::getTypeAndAttr(boxTy));
@@ -863,10 +876,10 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
               }
             }
           })
-          .Case<mlir::ComplexType>([&](mlir::ComplexType cmplx) {
+          .template Case<mlir::ComplexType>([&](mlir::ComplexType cmplx) {
             doComplexArg(func, cmplx, newInTyAndAttrs, fixups);
           })
-          .Case<mlir::TupleType>([&](mlir::TupleType tuple) {
+          .template Case<mlir::TupleType>([&](mlir::TupleType tuple) {
             if (fir::isCharacterProcedureTuple(tuple)) {
               fixups.emplace_back(FixupTy::Codes::TrailingCharProc,
                                   newInTyAndAttrs.size(), trailingTys.size());
@@ -878,7 +891,7 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
                   fir::CodeGenSpecifics::getTypeAndAttr(ty));
             }
           })
-          .Case<mlir::IntegerType>([&](mlir::IntegerType intTy) {
+          .template Case<mlir::IntegerType>([&](mlir::IntegerType intTy) {
             auto m = specifics->integerArgumentType(func.getLoc(), intTy);
             assert(m.size() == 1);
             auto attr = std::get<fir::CodeGenSpecifics::Attributes>(m[0]);
@@ -887,7 +900,7 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
             if (!extensionAttrName.empty() &&
                 isFuncWithCCallingConvention(func))
               fixups.emplace_back(FixupTy::Codes::ArgumentType, argNo,
-                                  [=](mlir::func::FuncOp func) {
+                                  [=](OpTy func) {
                                     func.setArgAttr(
                                         argNo, extensionAttrName,
                                         mlir::UnitAttr::get(func.getContext()));
@@ -898,13 +911,13 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
           .template Case<fir::RecordType>([&](fir::RecordType recTy) {
             doStructArg(func, recTy, newInTyAndAttrs, fixups);
           })
-          .Default([&](mlir::Type ty) {
+          .template Default([&](mlir::Type ty) {
             newInTyAndAttrs.push_back(
                 fir::CodeGenSpecifics::getTypeAndAttr(ty));
           });
 
-      if (func.getArgAttrOfType<mlir::UnitAttr>(index,
-                                                fir::getHostAssocAttrName())) {
+      if (func.template getArgAttrOfType<mlir::UnitAttr>(
+              index, fir::getHostAssocAttrName())) {
         extraAttrs.push_back(
             {newInTyAndAttrs.size() - 1,
              rewriter->getNamedAttr("llvm.nest", rewriter->getUnitAttr())});
@@ -979,29 +992,52 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
           auto newArg =
               func.front().insertArgument(fixup.index, fixupType, loc);
           offset++;
-          func.walk([&](mlir::func::ReturnOp ret) {
-            rewriter->setInsertionPoint(ret);
-            auto oldOper = ret.getOperand(0);
-            auto oldOperTy = fir::ReferenceType::get(oldOper.getType());
-            auto cast =
-                rewriter->create<fir::ConvertOp>(loc, oldOperTy, newArg);
-            rewriter->create<fir::StoreOp>(loc, oldOper, cast);
-            rewriter->create<mlir::func::ReturnOp>(loc);
-            ret.erase();
-          });
+          if constexpr (std::is_same_v<OpTy, mlir::func::FuncOp>)
+            func.walk([&](mlir::func::ReturnOp ret) {
+              rewriter->setInsertionPoint(ret);
+              auto oldOper = ret.getOperand(0);
+              auto oldOperTy = fir::ReferenceType::get(oldOper.getType());
+              auto cast =
+                  rewriter->create<fir::ConvertOp>(loc, oldOperTy, newArg);
+              rewriter->create<fir::StoreOp>(loc, oldOper, cast);
+              rewriter->create<mlir::func::ReturnOp>(loc);
+              ret.erase();
+            });
+          if constexpr (std::is_same_v<OpTy, mlir::gpu::GPUFuncOp>)
+            func.walk([&](mlir::gpu::ReturnOp ret) {
+              rewriter->setInsertionPoint(ret);
+              auto oldOper = ret.getOperand(0);
+              auto oldOperTy = fir::ReferenceType::get(oldOper.getType());
+              auto cast =
+                  rewriter->create<fir::ConvertOp>(loc, oldOperTy, newArg);
+              rewriter->create<fir::StoreOp>(loc, oldOper, cast);
+              rewriter->create<mlir::gpu::ReturnOp>(loc);
+              ret.erase();
+            });
         } break;
         case FixupTy::Codes::ReturnType: {
           // The function is still returning a value, but its type has likely
           // changed to suit the target ABI convention.
-          func.walk([&](mlir::func::ReturnOp ret) {
-            rewriter->setInsertionPoint(ret);
-            auto oldOper = ret.getOperand(0);
-            mlir::Value bitcast =
-                convertValueInMemory(loc, oldOper, newResTys[fixup.index],
-                                     /*inputMayBeBigger=*/false);
-            rewriter->create<mlir::func::ReturnOp>(loc, bitcast);
-            ret.erase();
-          });
+          if constexpr (std::is_same_v<OpTy, mlir::func::FuncOp>)
+            func.walk([&](mlir::func::ReturnOp ret) {
+              rewriter->setInsertionPoint(ret);
+              auto oldOper = ret.getOperand(0);
+              mlir::Value bitcast =
+                  convertValueInMemory(loc, oldOper, newResTys[fixup.index],
+                                       /*inputMayBeBigger=*/false);
+              rewriter->create<mlir::func::ReturnOp>(loc, bitcast);
+              ret.erase();
+            });
+          if constexpr (std::is_same_v<OpTy, mlir::gpu::GPUFuncOp>)
+            func.walk([&](mlir::gpu::ReturnOp ret) {
+              rewriter->setInsertionPoint(ret);
+              auto oldOper = ret.getOperand(0);
+              mlir::Value bitcast =
+                  convertValueInMemory(loc, oldOper, newResTys[fixup.index],
+                                       /*inputMayBeBigger=*/false);
+              rewriter->create<mlir::gpu::ReturnOp>(loc, bitcast);
+              ret.erase();
+            });
         } break;
         case FixupTy::Codes::Split: {
           // The FIR argument has been split into a pair of distinct arguments
@@ -1101,13 +1137,18 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
       }
     }
 
-    for (auto &fixup : fixups)
-      if (fixup.finalizer)
-        (*fixup.finalizer)(func);
+    for (auto &fixup : fixups) {
+      if constexpr (std::is_same_v<OpTy, mlir::func::FuncOp>)
+        if (fixup.finalizer)
+          (*fixup.finalizer)(func);
+      if constexpr (std::is_same_v<OpTy, mlir::gpu::GPUFuncOp>)
+        if (fixup.gpuFinalizer)
+          (*fixup.gpuFinalizer)(func);
+    }
   }
 
-  template <typename Ty, typename FIXUPS>
-  void doReturn(mlir::func::FuncOp func, Ty &newResTys,
+  template <typename OpTy, typename Ty, typename FIXUPS>
+  void doReturn(OpTy func, Ty &newResTys,
                 fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs,
                 FIXUPS &fixups, fir::CodeGenSpecifics::Marshalling &m) {
     assert(m.size() == 1 &&
@@ -1119,7 +1160,7 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
       unsigned argNo = newInTyAndAttrs.size();
       if (auto align = attr.getAlignment())
         fixups.emplace_back(
-            FixupTy::Codes::ReturnAsStore, argNo, [=](mlir::func::FuncOp func) {
+            FixupTy::Codes::ReturnAsStore, argNo, [=](OpTy func) {
               auto elemType = fir::dyn_cast_ptrOrBoxEleTy(
                   func.getFunctionType().getInput(argNo));
               func.setArgAttr(argNo, "llvm.sret",
@@ -1130,7 +1171,7 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
             });
       else
         fixups.emplace_back(FixupTy::Codes::ReturnAsStore, argNo,
-                            [=](mlir::func::FuncOp func) {
+                            [=](OpTy func) {
                               auto elemType = fir::dyn_cast_ptrOrBoxEleTy(
                                   func.getFunctionType().getInput(argNo));
                               func.setArgAttr(argNo, "llvm.sret",
@@ -1141,8 +1182,7 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
     }
     if (auto align = attr.getAlignment())
       fixups.emplace_back(
-          FixupTy::Codes::ReturnType, newResTys.size(),
-          [=](mlir::func::FuncOp func) {
+          FixupTy::Codes::ReturnType, newResTys.size(), [=](OpTy func) {
             func.setArgAttr(
                 newResTys.size(), "llvm.align",
                 rewriter->getIntegerAttr(rewriter->getIntegerType(32), align));
@@ -1155,9 +1195,8 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
   /// Convert a complex return value. This can involve converting the return
   /// value to a "hidden" first argument or packing the complex into a wide
   /// GPR.
-  template <typename Ty, typename FIXUPS>
-  void doComplexReturn(mlir::func::FuncOp func, mlir::ComplexType cmplx,
-                       Ty &newResTys,
+  template <typename OpTy, typename Ty, typename FIXUPS>
+  void doComplexReturn(OpTy func, mlir::ComplexType cmplx, Ty &newResTys,
                        fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs,
                        FIXUPS &fixups) {
     if (noComplexConversion) {
@@ -1169,9 +1208,8 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
     doReturn(func, newResTys, newInTyAndAttrs, fixups, m);
   }
 
-  template <typename Ty, typename FIXUPS>
-  void doStructReturn(mlir::func::FuncOp func, fir::RecordType recTy,
-                      Ty &newResTys,
+  template <typename OpTy, typename Ty, typename FIXUPS>
+  void doStructReturn(OpTy func, fir::RecordType recTy, Ty &newResTys,
                       fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs,
                       FIXUPS &fixups) {
     if (noStructConversion) {
@@ -1182,12 +1220,10 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
     doReturn(func, newResTys, newInTyAndAttrs, fixups, m);
   }
 
-  template <typename FIXUPS>
-  void
-  createFuncOpArgFixups(mlir::func::FuncOp func,
-                        fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs,
-                        fir::CodeGenSpecifics::Marshalling &argsInTys,
-                        FIXUPS &fixups) {
+  template <typename OpTy, typename FIXUPS>
+  void createFuncOpArgFixups(
+      OpTy func, fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs,
+      fir::CodeGenSpecifics::Marshalling &argsInTys, FIXUPS &fixups) {
     const auto fixupCode = argsInTys.size() > 1 ? FixupTy::Codes::Split
                                                 : FixupTy::Codes::ArgumentType;
     for (auto e : llvm::enumerate(argsInTys)) {
@@ -1198,7 +1234,7 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
       if (attr.isByVal()) {
         if (auto align = attr.getAlignment())
           fixups.emplace_back(FixupTy::Codes::ArgumentAsLoad, argNo,
-                              [=](mlir::func::FuncOp func) {
+                              [=](OpTy func) {
                                 auto elemType = fir::dyn_cast_ptrOrBoxEleTy(
                                     func.getFunctionType().getInput(argNo));
                                 func.setArgAttr(argNo, "llvm.byval",
@@ -1210,8 +1246,7 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
                               });
         else
           fixups.emplace_back(FixupTy::Codes::ArgumentAsLoad,
-                              newInTyAndAttrs.size(),
-                              [=](mlir::func::FuncOp func) {
+                              newInTyAndAttrs.size(), [=](OpTy func) {
                                 auto elemType = fir::dyn_cast_ptrOrBoxEleTy(
                                     func.getFunctionType().getInput(argNo));
                                 func.setArgAttr(argNo, "llvm.byval",
@@ -1220,7 +1255,7 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
       } else {
         if (auto align = attr.getAlignment())
           fixups.emplace_back(
-              fixupCode, argNo, index, [=](mlir::func::FuncOp func) {
+              fixupCode, argNo, index, [=](OpTy func) {
                 func.setArgAttr(argNo, "llvm.align",
                                 rewriter->getIntegerAttr(
                                     rewriter->getIntegerType(32), align));
@@ -1235,8 +1270,8 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
   /// Convert a complex argument value. This can involve storing the value to
   /// a temporary memory location or factoring the value into two distinct
   /// arguments.
-  template <typename FIXUPS>
-  void doComplexArg(mlir::func::FuncOp func, mlir::ComplexType cmplx,
+  template <typename OpTy, typename FIXUPS>
+  void doComplexArg(OpTy func, mlir::ComplexType cmplx,
                     fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs,
                     FIXUPS &fixups) {
     if (noComplexConversion) {
@@ -1248,8 +1283,8 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
     createFuncOpArgFixups(func, newInTyAndAttrs, cplxArgs, fixups);
   }
 
-  template <typename FIXUPS>
-  void doStructArg(mlir::func::FuncOp func, fir::RecordType recTy,
+  template <typename OpTy, typename FIXUPS>
+  void doStructArg(OpTy func, fir::RecordType recTy,
                    fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs,
                    FIXUPS &fixups) {
     if (noStructConversion) {
diff --git a/flang/test/Fir/CUDA/cuda-target-rewrite.mlir b/flang/test/Fir/CUDA/cuda-target-rewrite.mlir
index c14efc8b13f66b..d88b6776795a0b 100644
--- a/flang/test/Fir/CUDA/cuda-target-rewrite.mlir
+++ b/flang/test/Fir/CUDA/cuda-target-rewrite.mlir
@@ -1,5 +1,5 @@
 // REQUIRES: x86-registered-target
-// RUN: fir-opt --target-rewrite="target=x86_64-unknown-linux-gnu" %s | FileCheck %s
+// RUN: fir-opt --split-input-file --target-rewrite="target=x86_64-unknown-linux-gnu" %s | FileCheck %s
 
 gpu.module @testmod {
   gpu.func @_QPvcpowdk(%arg0: !fir.ref<complex<f64>> {cuf.data_attr = #cuf.cuda<device>, fir.bindc_name = "a"}) attributes {cuf.proc_attr = #cuf.cuda_proc<global>} {
@@ -15,3 +15,15 @@ gpu.module @testmod {
 // CHECK-LABEL: gpu.func @_QPvcpowdk
 // CHECK: %{{.*}} = fir.call @_FortranAzpowk(%{{.*}}, %{{.*}}, %{{.*}}) : (f64, f64, i64) -> tuple<f64, f64>
 // CHECK: func.func private @_FortranAzpowk(f64, f64, i64) -> tuple<f64, f64> attributes {fir.bindc_name = "_FortranAzpowk", fir.runtime}
+
+// -----
+
+gpu.module @testmod {
+  gpu.func @_QPtest(%arg0: complex<f64>) -> (complex<f64>) {
+    gpu.return %arg0 : complex<f64>
+  }
+}
+
+// CHECK-LABEL: gpu.func @_QPtest
+// CHECK-SAME: (%arg0: f64, %arg1: f64) -> tuple<f64, f64> {
+// CHECK: gpu.return %{{.*}} : tuple<f64, f64>

Copy link
Contributor

@jeanPerier jeanPerier left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Small suggestion to avoid duplicating return op handling, LGTM otherwise.

@clementval clementval merged commit dc5236e into llvm:main Dec 10, 2024
8 checks passed
@clementval clementval deleted the cuf_targetrewrite branch December 10, 2024 20:36
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
flang:codegen flang:fir-hlfir flang Flang issues not falling into any other category
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants