Skip to content

[flang][cuda] Convert cuf.alloc and cuf.free for scalar and arrays #110055

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Sep 30, 2024

Conversation

clementval
Copy link
Contributor

This patch adds more conversion of cuf.alloc and cuf.free for scalars, constant size arrays and dynamic size arrays

@clementval clementval requested a review from wangzpgi September 25, 2024 22:38
@llvmbot llvmbot added flang Flang issues not falling into any other category flang:fir-hlfir labels Sep 25, 2024
@llvmbot
Copy link
Member

llvmbot commented Sep 25, 2024

@llvm/pr-subscribers-flang-fir-hlfir

Author: Valentin Clement (バレンタイン クレメン) (clementval)

Changes

This patch adds more conversion of cuf.alloc and cuf.free for scalars, constant size arrays and dynamic size arrays


Full diff: https://github.com/llvm/llvm-project/pull/110055.diff

3 Files Affected:

  • (modified) flang/lib/Optimizer/Transforms/CufOpConversion.cpp (+92-49)
  • (added) flang/test/Fir/CUDA/cuda-alloc-free.fir (+64)
  • (modified) flang/test/Fir/CUDA/cuda-allocate.fir (-11)
diff --git a/flang/lib/Optimizer/Transforms/CufOpConversion.cpp b/flang/lib/Optimizer/Transforms/CufOpConversion.cpp
index f8ace2dd96a0d8..fff026bbda2d40 100644
--- a/flang/lib/Optimizer/Transforms/CufOpConversion.cpp
+++ b/flang/lib/Optimizer/Transforms/CufOpConversion.cpp
@@ -183,6 +183,29 @@ static bool inDeviceContext(mlir::Operation *op) {
   return false;
 }
 
+static int computeWidth(mlir::Location loc, mlir::Type type,
+                        fir::KindMapping &kindMap) {
+  auto eleTy = fir::unwrapSequenceType(type);
+  int width = 0;
+  if (auto t{mlir::dyn_cast<mlir::IntegerType>(eleTy)}) {
+    width = t.getWidth() / 8;
+  } else if (auto t{mlir::dyn_cast<mlir::FloatType>(eleTy)}) {
+    width = t.getWidth() / 8;
+  } else if (eleTy.isInteger(1)) {
+    width = 1;
+  } else if (auto t{mlir::dyn_cast<fir::LogicalType>(eleTy)}) {
+    int kind = t.getFKind();
+    width = kindMap.getLogicalBitsize(kind) / 8;
+  } else if (auto t{mlir::dyn_cast<fir::ComplexType>(eleTy)}) {
+    int kind = t.getFKind();
+    int elemSize = kindMap.getRealBitsize(kind) / 8;
+    width = 2 * elemSize;
+  } else {
+    llvm::report_fatal_error("unsupported type");
+  }
+  return width;
+}
+
 struct CufAllocOpConversion : public mlir::OpRewritePattern<cuf::AllocOp> {
   using OpRewritePattern::OpRewritePattern;
 
@@ -193,11 +216,6 @@ struct CufAllocOpConversion : public mlir::OpRewritePattern<cuf::AllocOp> {
   mlir::LogicalResult
   matchAndRewrite(cuf::AllocOp op,
                   mlir::PatternRewriter &rewriter) const override {
-    auto boxTy = mlir::dyn_cast_or_null<fir::BaseBoxType>(op.getInType());
-
-    // Only convert cuf.alloc that allocates a descriptor.
-    if (!boxTy)
-      return failure();
 
     if (inDeviceContext(op.getOperation())) {
       // In device context just replace the cuf.alloc operation with a fir.alloc
@@ -212,11 +230,56 @@ struct CufAllocOpConversion : public mlir::OpRewritePattern<cuf::AllocOp> {
     auto mod = op->getParentOfType<mlir::ModuleOp>();
     fir::FirOpBuilder builder(rewriter, mod);
     mlir::Location loc = op.getLoc();
+    mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
+
+    if (!mlir::dyn_cast_or_null<fir::BaseBoxType>(op.getInType())) {
+      // Convert scalar and known size array allocations.
+      mlir::Value bytes;
+      fir::KindMapping kindMap{fir::getKindMapping(mod)};
+      if (fir::isa_trivial(op.getInType())) {
+        int width = computeWidth(loc, op.getInType(), kindMap);
+        bytes =
+            builder.createIntegerConstant(loc, builder.getIndexType(), width);
+      } else if (auto seqTy = mlir::dyn_cast_or_null<fir::SequenceType>(
+                     op.getInType())) {
+        mlir::Value width = builder.createIntegerConstant(
+            loc, builder.getIndexType(),
+            computeWidth(loc, seqTy.getEleTy(), kindMap));
+        mlir::Value nbElem;
+        if (fir::sequenceWithNonConstantShape(seqTy)) {
+          assert(!op.getShape().empty() && "expect shape with dynamic arrays");
+          nbElem = builder.loadIfRef(loc, op.getShape()[0]);
+          for (unsigned i = 1; i < op.getShape().size(); ++i) {
+            nbElem = rewriter.create<mlir::arith::MulIOp>(
+                loc, nbElem, builder.loadIfRef(loc, op.getShape()[i]));
+          }
+        } else {
+          nbElem = builder.createIntegerConstant(loc, builder.getIndexType(),
+                                                 seqTy.getConstantArraySize());
+        }
+        bytes = rewriter.create<mlir::arith::MulIOp>(loc, nbElem, width);
+      }
+      mlir::func::FuncOp func =
+          fir::runtime::getRuntimeFunc<mkRTKey(CUFMemAlloc)>(loc, builder);
+      auto fTy = func.getFunctionType();
+      mlir::Value sourceLine =
+          fir::factory::locationToLineNo(builder, loc, fTy.getInput(3));
+      mlir::Value memTy = builder.createIntegerConstant(
+          loc, builder.getI32Type(), kMemTypeDevice);
+      llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
+          builder, loc, fTy, bytes, memTy, sourceFile, sourceLine)};
+      auto callOp = builder.create<fir::CallOp>(loc, func, args);
+      auto convOp = builder.createConvert(loc, op.getResult().getType(),
+                                          callOp.getResult(0));
+      rewriter.replaceOp(op, convOp);
+      return mlir::success();
+    }
+
+    // Convert descriptor allocations to function call.
+    auto boxTy = mlir::dyn_cast_or_null<fir::BaseBoxType>(op.getInType());
     mlir::func::FuncOp func =
         fir::runtime::getRuntimeFunc<mkRTKey(CUFAllocDesciptor)>(loc, builder);
-
     auto fTy = func.getFunctionType();
-    mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
     mlir::Value sourceLine =
         fir::factory::locationToLineNo(builder, loc, fTy.getInput(2));
 
@@ -245,26 +308,39 @@ struct CufFreeOpConversion : public mlir::OpRewritePattern<cuf::FreeOp> {
   mlir::LogicalResult
   matchAndRewrite(cuf::FreeOp op,
                   mlir::PatternRewriter &rewriter) const override {
-    // Only convert cuf.free on descriptor.
-    if (!mlir::isa<fir::ReferenceType>(op.getDevptr().getType()))
-      return failure();
-    auto refTy = mlir::dyn_cast<fir::ReferenceType>(op.getDevptr().getType());
-    if (!mlir::isa<fir::BaseBoxType>(refTy.getEleTy()))
-      return failure();
-
     if (inDeviceContext(op.getOperation())) {
       rewriter.eraseOp(op);
       return mlir::success();
     }
 
+    if (!mlir::isa<fir::ReferenceType>(op.getDevptr().getType()))
+      return failure();
+
     auto mod = op->getParentOfType<mlir::ModuleOp>();
     fir::FirOpBuilder builder(rewriter, mod);
     mlir::Location loc = op.getLoc();
+    mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
+
+    auto refTy = mlir::dyn_cast<fir::ReferenceType>(op.getDevptr().getType());
+    if (!mlir::isa<fir::BaseBoxType>(refTy.getEleTy())) {
+      mlir::func::FuncOp func =
+          fir::runtime::getRuntimeFunc<mkRTKey(CUFMemFree)>(loc, builder);
+      auto fTy = func.getFunctionType();
+      mlir::Value sourceLine =
+          fir::factory::locationToLineNo(builder, loc, fTy.getInput(3));
+      mlir::Value memTy = builder.createIntegerConstant(
+          loc, builder.getI32Type(), kMemTypeDevice);
+      llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
+          builder, loc, fTy, op.getDevptr(), memTy, sourceFile, sourceLine)};
+      builder.create<fir::CallOp>(loc, func, args);
+      rewriter.eraseOp(op);
+      return mlir::success();
+    }
+
+    // Convert cuf.free on descriptors.
     mlir::func::FuncOp func =
         fir::runtime::getRuntimeFunc<mkRTKey(CUFFreeDesciptor)>(loc, builder);
-
     auto fTy = func.getFunctionType();
-    mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
     mlir::Value sourceLine =
         fir::factory::locationToLineNo(builder, loc, fTy.getInput(2));
     llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
@@ -275,29 +351,6 @@ struct CufFreeOpConversion : public mlir::OpRewritePattern<cuf::FreeOp> {
   }
 };
 
-static int computeWidth(mlir::Location loc, mlir::Type type,
-                        fir::KindMapping &kindMap) {
-  auto eleTy = fir::unwrapSequenceType(type);
-  int width = 0;
-  if (auto t{mlir::dyn_cast<mlir::IntegerType>(eleTy)}) {
-    width = t.getWidth() / 8;
-  } else if (auto t{mlir::dyn_cast<mlir::FloatType>(eleTy)}) {
-    width = t.getWidth() / 8;
-  } else if (eleTy.isInteger(1)) {
-    width = 1;
-  } else if (auto t{mlir::dyn_cast<fir::LogicalType>(eleTy)}) {
-    int kind = t.getFKind();
-    width = kindMap.getLogicalBitsize(kind) / 8;
-  } else if (auto t{mlir::dyn_cast<fir::ComplexType>(eleTy)}) {
-    int kind = t.getFKind();
-    int elemSize = kindMap.getRealBitsize(kind) / 8;
-    width = 2 * elemSize;
-  } else {
-    llvm::report_fatal_error("unsupported type");
-  }
-  return width;
-}
-
 static mlir::Value createConvertOp(mlir::PatternRewriter &rewriter,
                                    mlir::Location loc, mlir::Type toTy,
                                    mlir::Value val) {
@@ -456,16 +509,6 @@ class CufOpConversion : public fir::impl::CufOpConversionBase<CufOpConversion> {
         fir::support::getOrSetDataLayout(module, /*allowDefaultLayout=*/false);
     fir::LLVMTypeConverter typeConverter(module, /*applyTBAA=*/false,
                                          /*forceUnifiedTBAATree=*/false, *dl);
-    target.addDynamicallyLegalOp<cuf::AllocOp>([](::cuf::AllocOp op) {
-      return !mlir::isa<fir::BaseBoxType>(op.getInType());
-    });
-    target.addDynamicallyLegalOp<cuf::FreeOp>([](::cuf::FreeOp op) {
-      if (auto refTy = mlir::dyn_cast_or_null<fir::ReferenceType>(
-              op.getDevptr().getType())) {
-        return !mlir::isa<fir::BaseBoxType>(refTy.getEleTy());
-      }
-      return true;
-    });
     target.addDynamicallyLegalOp<cuf::DataTransferOp>(
         [](::cuf::DataTransferOp op) {
           mlir::Type srcTy = fir::unwrapRefType(op.getSrc().getType());
diff --git a/flang/test/Fir/CUDA/cuda-alloc-free.fir b/flang/test/Fir/CUDA/cuda-alloc-free.fir
new file mode 100644
index 00000000000000..25821418a40f11
--- /dev/null
+++ b/flang/test/Fir/CUDA/cuda-alloc-free.fir
@@ -0,0 +1,64 @@
+// RUN: fir-opt --cuf-convert %s | FileCheck %s
+
+module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<f80, dense<128> : vector<2xi64>>, #dlti.dl_entry<i128, dense<128> : vector<2xi64>>, #dlti.dl_entry<i64, dense<64> : vector<2xi64>>, #dlti.dl_entry<!llvm.ptr<272>, dense<64> : vector<4xi64>>, #dlti.dl_entry<!llvm.ptr<271>, dense<32> : vector<4xi64>>, #dlti.dl_entry<!llvm.ptr<270>, dense<32> : vector<4xi64>>, #dlti.dl_entry<f128, dense<128> : vector<2xi64>>, #dlti.dl_entry<f64, dense<64> : vector<2xi64>>, #dlti.dl_entry<f16, dense<16> : vector<2xi64>>, #dlti.dl_entry<i32, dense<32> : vector<2xi64>>, #dlti.dl_entry<i16, dense<16> : vector<2xi64>>, #dlti.dl_entry<i8, dense<8> : vector<2xi64>>, #dlti.dl_entry<i1, dense<8> : vector<2xi64>>, #dlti.dl_entry<!llvm.ptr, dense<64> : vector<4xi64>>, #dlti.dl_entry<"dlti.endianness", "little">, #dlti.dl_entry<"dlti.stack_alignment", 128 : i64>>} {
+
+func.func @_QPsub1() {
+  %0 = cuf.alloc i32 {bindc_name = "idev", data_attr = #cuf.cuda<device>, uniq_name = "_QFsub1Eidev"} -> !fir.ref<i32>
+  %1:2 = hlfir.declare %0 {data_attr = #cuf.cuda<device>, uniq_name = "_QFsub1Eidev"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+  cuf.free %1#1 : !fir.ref<i32> {data_attr = #cuf.cuda<device>}
+  return
+}
+
+// CHECK-LABEL: func.func @_QPsub1()
+// CHECK: %[[BYTES:.*]] = fir.convert %c4{{.*}} : (index) -> i64
+// CHECK: %[[ALLOC:.*]] = fir.call @_FortranACUFMemAlloc(%[[BYTES]], %c0{{.*}}, %{{.*}}, %{{.*}}) : (i64, i32, !fir.ref<i8>, i32) -> !fir.llvm_ptr<i8>
+// CHECK: %[[CONV:.*]] = fir.convert %3 : (!fir.llvm_ptr<i8>) -> !fir.ref<i32>
+// CHECK: %[[DECL:.*]]:2 = hlfir.declare %[[CONV]] {data_attr = #cuf.cuda<device>, uniq_name = "_QFsub1Eidev"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+// CHECK: %[[DEVPTR:.*]] = fir.convert %[[DECL]]#1 : (!fir.ref<i32>) -> !fir.llvm_ptr<i8>
+// CHECK: fir.call @_FortranACUFMemFree(%[[DEVPTR]], %c0{{.*}}, %{{.*}}, %{{.*}}) : (!fir.llvm_ptr<i8>, i32, !fir.ref<i8>, i32) -> none
+
+func.func @_QPsub2() {
+  %0 = cuf.alloc !fir.array<10xf32> {bindc_name = "a", data_attr = #cuf.cuda<device>, uniq_name = "_QMcuda_varFcuda_alloc_freeEa"} -> !fir.ref<!fir.array<10xf32>>
+  cuf.free %0 : !fir.ref<!fir.array<10xf32>> {data_attr = #cuf.cuda<device>}
+  return
+}
+
+// CHECK-LABEL: func.func @_QPsub2()
+// CHECK: %[[BYTES:.*]] = arith.muli %c10{{.*}}, %c4{{.*}} : index
+// CHECK: %[[CONV_BYTES:.*]] = fir.convert %[[BYTES]] : (index) -> i64 
+// CHECK: %{{.*}} = fir.call @_FortranACUFMemAlloc(%[[CONV_BYTES]], %c0{{.*}}, %{{.*}}, %{{.*}}) : (i64, i32, !fir.ref<i8>, i32) -> !fir.llvm_ptr<i8>
+// CHECK: fir.call @_FortranACUFMemFree
+
+func.func @_QPsub3(%arg0: !fir.ref<i32> {fir.bindc_name = "n"}, %arg1: !fir.ref<i32> {fir.bindc_name = "m"}) {
+  %0 = fir.dummy_scope : !fir.dscope
+  %1:2 = hlfir.declare %arg0 dummy_scope %0 {uniq_name = "_QFsub3En"} : (!fir.ref<i32>, !fir.dscope) -> (!fir.ref<i32>, !fir.ref<i32>)
+  %2:2 = hlfir.declare %arg1 dummy_scope %0 {uniq_name = "_QFsub3Em"} : (!fir.ref<i32>, !fir.dscope) -> (!fir.ref<i32>, !fir.ref<i32>)
+  %3 = fir.load %1#0 : !fir.ref<i32>
+  %4 = fir.convert %3 : (i32) -> i64
+  %5 = fir.convert %4 : (i64) -> index
+  %c0 = arith.constant 0 : index
+  %6 = arith.cmpi sgt, %5, %c0 : index
+  %7 = arith.select %6, %5, %c0 : index
+  %8 = fir.load %2#0 : !fir.ref<i32>
+  %9 = fir.convert %8 : (i32) -> i64
+  %10 = fir.convert %9 : (i64) -> index
+  %c0_0 = arith.constant 0 : index
+  %11 = arith.cmpi sgt, %10, %c0_0 : index
+  %12 = arith.select %11, %10, %c0_0 : index
+  %13 = cuf.alloc !fir.array<?x?xi32>, %7, %12 : index, index {bindc_name = "idev", data_attr = #cuf.cuda<device>, uniq_name = "_QFsub3Eidev"} -> !fir.ref<!fir.array<?x?xi32>>
+  %14 = fir.shape %7, %12 : (index, index) -> !fir.shape<2>
+  %15:2 = hlfir.declare %13(%14) {data_attr = #cuf.cuda<device>, uniq_name = "_QFsub3Eidev"} : (!fir.ref<!fir.array<?x?xi32>>, !fir.shape<2>) -> (!fir.box<!fir.array<?x?xi32>>, !fir.ref<!fir.array<?x?xi32>>)
+  cuf.free %15#1 : !fir.ref<!fir.array<?x?xi32>> {data_attr = #cuf.cuda<device>}
+  return
+}
+
+// CHECK-LABEL: func.func @_QPsub3
+// CHECK: %[[N:.*]] = arith.select 
+// CHECK: %[[M:.*]] = arith.select
+// CHECK: %[[NBELEM:.*]] = arith.muli %[[N]], %[[M]] : index
+// CHECK: %[[BYTES:.*]] = arith.muli %[[NBELEM]], %c4{{.*}} : index
+// CHECK: %[[CONV_BYTES:.*]] = fir.convert %[[BYTES]] : (index) -> i64
+// CHECK: %{{.*}} = fir.call @_FortranACUFMemAlloc(%[[CONV_BYTES]], %c0{{.*}}, %{{.*}}, %{{.*}}) : (i64, i32, !fir.ref<i8>, i32) -> !fir.llvm_ptr<i8>
+// CHECK: fir.call @_FortranACUFMemFree
+
+} // end module
diff --git a/flang/test/Fir/CUDA/cuda-allocate.fir b/flang/test/Fir/CUDA/cuda-allocate.fir
index 65c68bb69301af..d68ff894d5af5a 100644
--- a/flang/test/Fir/CUDA/cuda-allocate.fir
+++ b/flang/test/Fir/CUDA/cuda-allocate.fir
@@ -26,17 +26,6 @@ func.func @_QPsub1() {
 // CHECK: %[[BOX_NONE:.*]] = fir.convert %[[DECL_DESC]]#1 : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>) -> !fir.ref<!fir.box<none>>
 // CHECK: fir.call @_FortranACUFFreeDesciptor(%[[BOX_NONE]], %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, !fir.ref<i8>, i32) -> none
 
-// Check operations that should not be transformed yet.
-func.func @_QPsub2() {
-  %0 = cuf.alloc !fir.array<10xf32> {bindc_name = "a", data_attr = #cuf.cuda<device>, uniq_name = "_QMcuda_varFcuda_alloc_freeEa"} -> !fir.ref<!fir.array<10xf32>>
-  cuf.free %0 : !fir.ref<!fir.array<10xf32>> {data_attr = #cuf.cuda<device>}
-  return
-}
-
-// CHECK-LABEL: func.func @_QPsub2()
-// CHECK: cuf.alloc !fir.array<10xf32>
-// CHECK: cuf.free %{{.*}} : !fir.ref<!fir.array<10xf32>>
-
 fir.global @_QMmod1Ea {data_attr = #cuf.cuda<device>} : !fir.box<!fir.heap<!fir.array<?xf32>>> {
     %0 = fir.zero_bits !fir.heap<!fir.array<?xf32>>
     %c0 = arith.constant 0 : index

@clementval clementval merged commit 39e254e into llvm:main Sep 30, 2024
8 checks passed
@clementval clementval deleted the cuf_op_conversion_alloc branch September 30, 2024 16:48
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
flang:fir-hlfir flang Flang issues not falling into any other category
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants