[flang][cuda] Convert cuf.data_transfer with descriptors #108890


Merged: 1 commit into llvm:main on Sep 17, 2024

Conversation

clementval (Contributor)

Convert cuf.data_transfer operations involving descriptors to the newly introduced entry points (#108244).

@llvmbot added the flang (Flang issues not falling into any other category) and flang:fir-hlfir labels on Sep 16, 2024
@llvmbot (Member) commented on Sep 16, 2024

@llvm/pr-subscribers-flang-fir-hlfir

Author: Valentin Clement (バレンタイン クレメン) (clementval)

Changes

Convert cuf.data_transfer operations involving descriptors to the newly introduced entry points (#108244).
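
To summarize the conversion logic, the new `CufDataTransferOpConversion` pattern in the diff below selects one of four runtime entry points depending on whether each operand of the `cuf.data_transfer` is a descriptor (`fir.box`) or raw memory. The stand-alone C++ sketch below only restates that selection for illustration; it is not part of the patch.

```cpp
// Illustrative sketch (not from the patch): which runtime entry point the
// pattern picks for a cuf.data_transfer, based on the operand types, as
// implemented by CufDataTransferOpConversion in the diff below.
#include <string>

std::string selectEntryPoint(bool dstIsDescriptor, bool srcIsDescriptor,
                             bool srcIsTrivialScalar) {
  if (!dstIsDescriptor && !srcIsDescriptor)
    return "<not handled by this pattern; the op stays legal here>";
  if (dstIsDescriptor && srcIsDescriptor)
    return "_FortranACUFDataTransferDescDesc";   // box -> box
  if (dstIsDescriptor && srcIsTrivialScalar)
    return "_FortranACUFMemsetDescriptor";       // scalar broadcast into a box
  if (dstIsDescriptor)
    return "_FortranACUFDataTransferDescPtr";    // raw pointer -> box
  return "_FortranACUFDataTransferPtrDesc";      // box -> raw pointer
}
```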


Patch is 23.49 KiB, truncated to 20.00 KiB below; full version: https://github.com/llvm/llvm-project/pull/108890.diff

2 Files Affected:

  • (modified) flang/lib/Optimizer/Transforms/CufOpConversion.cpp (+175-2)
  • (added) flang/test/Fir/CUDA/cuda-data-transfer.fir (+140)
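
The pattern calls the CUF runtime entry points introduced in #108244. As a rough reference, the sketch below reconstructs their C-level signatures from the `fir.call` sites in the new test; the authoritative declarations live in `flang/include/flang/Runtime/CUDA/memory.h` and may differ in exact types and qualifiers, so treat this as an assumption.

```cpp
// Sketch only: C-level view of the runtime entry points this patch targets.
// Parameter order is inferred from the fir.call sites in the new test; see
// flang/include/flang/Runtime/CUDA/memory.h (added in #108244) for the
// authoritative declarations.
#include <cstddef>

namespace Fortran::runtime {
class Descriptor; // opaque here; the real class is defined by the runtime
}
using Fortran::runtime::Descriptor;

// Transfer-mode values passed as the `mode` argument. 0 (host->device) and
// 1 (device->host) match the constants emitted in the new test; the
// device->device value is an assumption.
constexpr unsigned kHostToDevice = 0;
constexpr unsigned kDeviceToHost = 1;
constexpr unsigned kDeviceToDevice = 2;

extern "C" {
// Descriptor -> descriptor transfer: (box, box, mode, file, line).
void _FortranACUFDataTransferDescDesc(Descriptor *dst, Descriptor *src,
                                      unsigned mode, const char *sourceFile,
                                      int sourceLine);
// Raw-pointer source into a descriptor destination, with an explicit byte count.
void _FortranACUFDataTransferDescPtr(Descriptor *dst, void *src,
                                     std::size_t bytes, unsigned mode,
                                     const char *sourceFile, int sourceLine);
// Descriptor source into a raw-pointer destination.
void _FortranACUFDataTransferPtrDesc(void *dst, Descriptor *src,
                                     std::size_t bytes, unsigned mode,
                                     const char *sourceFile, int sourceLine);
// Assign a scalar value to every element described by the descriptor.
void _FortranACUFMemsetDescriptor(Descriptor *desc, void *value,
                                  const char *sourceFile, int sourceLine);
}
```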
diff --git a/flang/lib/Optimizer/Transforms/CufOpConversion.cpp b/flang/lib/Optimizer/Transforms/CufOpConversion.cpp
index c22c74d3f78af7..03a1eb74343b43 100644
--- a/flang/lib/Optimizer/Transforms/CufOpConversion.cpp
+++ b/flang/lib/Optimizer/Transforms/CufOpConversion.cpp
@@ -15,6 +15,7 @@
 #include "flang/Optimizer/HLFIR/HLFIROps.h"
 #include "flang/Optimizer/Support/DataLayout.h"
 #include "flang/Runtime/CUDA/descriptor.h"
+#include "flang/Runtime/CUDA/memory.h"
 #include "flang/Runtime/allocatable.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/DialectConversion.h"
@@ -255,6 +256,171 @@ struct CufFreeOpConversion : public mlir::OpRewritePattern<cuf::FreeOp> {
   }
 };
 
+static int computeWidth(mlir::Location loc, mlir::Type type,
+                        fir::KindMapping &kindMap) {
+  auto eleTy = fir::unwrapSequenceType(type);
+  int width = 0;
+  if (auto t{mlir::dyn_cast<mlir::IntegerType>(eleTy)}) {
+    width = t.getWidth() / 8;
+  } else if (auto t{mlir::dyn_cast<mlir::FloatType>(eleTy)}) {
+    width = t.getWidth() / 8;
+  } else if (eleTy.isInteger(1)) {
+    width = 1;
+  } else if (auto t{mlir::dyn_cast<fir::LogicalType>(eleTy)}) {
+    int kind = t.getFKind();
+    width = kindMap.getLogicalBitsize(kind) / 8;
+  } else if (auto t{mlir::dyn_cast<fir::ComplexType>(eleTy)}) {
+    int kind = t.getFKind();
+    int elemSize = kindMap.getRealBitsize(kind) / 8;
+    width = 2 * elemSize;
+  } else {
+    llvm::report_fatal_error("unsupported type");
+  }
+  return width;
+}
+
+static mlir::Value createConvertOp(mlir::PatternRewriter &rewriter,
+                                   mlir::Location loc, mlir::Type toTy,
+                                   mlir::Value val) {
+  if (val.getType() != toTy)
+    return rewriter.create<fir::ConvertOp>(loc, toTy, val);
+  return val;
+}
+
+struct CufDataTransferOpConversion
+    : public mlir::OpRewritePattern<cuf::DataTransferOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  mlir::LogicalResult
+  matchAndRewrite(cuf::DataTransferOp op,
+                  mlir::PatternRewriter &rewriter) const override {
+
+    mlir::Type srcTy = fir::unwrapRefType(op.getSrc().getType());
+    mlir::Type dstTy = fir::unwrapRefType(op.getDst().getType());
+
+    // Only convert cuf.data_transfer with at least one descriptor.
+    if (!mlir::isa<fir::BaseBoxType>(srcTy) &&
+        !mlir::isa<fir::BaseBoxType>(dstTy))
+      return failure();
+
+    unsigned mode;
+    if (op.getTransferKind() == cuf::DataTransferKind::HostDevice) {
+      mode = kHostToDevice;
+    } else if (op.getTransferKind() == cuf::DataTransferKind::DeviceHost) {
+      mode = kDeviceToHost;
+    } else if (op.getTransferKind() == cuf::DataTransferKind::DeviceDevice) {
+      mode = kDeviceToDevice;
+    }
+
+    auto mod = op->getParentOfType<mlir::ModuleOp>();
+    fir::FirOpBuilder builder(rewriter, mod);
+    mlir::Location loc = op.getLoc();
+
+    if (mlir::isa<fir::BaseBoxType>(srcTy) &&
+        mlir::isa<fir::BaseBoxType>(dstTy)) {
+      // Transfer between two descriptors.
+      mlir::func::FuncOp func =
+          fir::runtime::getRuntimeFunc<mkRTKey(CUFDataTransferDescDesc)>(
+              loc, builder);
+
+      auto fTy = func.getFunctionType();
+      mlir::Value modeValue =
+          builder.createIntegerConstant(loc, builder.getI32Type(), mode);
+      mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
+      mlir::Value sourceLine =
+          fir::factory::locationToLineNo(builder, loc, fTy.getInput(4));
+      mlir::Value dst = builder.loadIfRef(loc, op.getDst());
+      mlir::Value src = builder.loadIfRef(loc, op.getSrc());
+      llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
+          builder, loc, fTy, dst, src, modeValue, sourceFile, sourceLine)};
+      builder.create<fir::CallOp>(loc, func, args);
+      rewriter.eraseOp(op);
+    } else if (mlir::isa<fir::BaseBoxType>(dstTy) && fir::isa_trivial(srcTy)) {
+      // Scalar to descriptor transfer.
+      mlir::Value val = op.getSrc();
+      if (op.getSrc().getDefiningOp() &&
+          mlir::isa<mlir::arith::ConstantOp>(op.getSrc().getDefiningOp())) {
+        mlir::Value alloc = builder.createTemporary(loc, srcTy);
+        builder.create<fir::StoreOp>(loc, op.getSrc(), alloc);
+        val = alloc;
+      }
+
+      mlir::func::FuncOp func =
+          fir::runtime::getRuntimeFunc<mkRTKey(CUFMemsetDescriptor)>(loc,
+                                                                     builder);
+      auto fTy = func.getFunctionType();
+      mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
+      mlir::Value sourceLine =
+          fir::factory::locationToLineNo(builder, loc, fTy.getInput(3));
+      mlir::Value dst = builder.loadIfRef(loc, op.getDst());
+      llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
+          builder, loc, fTy, dst, val, sourceFile, sourceLine)};
+      builder.create<fir::CallOp>(loc, func, args);
+      rewriter.eraseOp(op);
+    } else {
+      mlir::Value modeValue =
+          builder.createIntegerConstant(loc, builder.getI32Type(), mode);
+      // Type used to compute the width.
+      mlir::Type computeType = dstTy;
+      auto seqTy = mlir::dyn_cast<fir::SequenceType>(dstTy);
+      bool dstIsDesc = false;
+      if (mlir::isa<fir::BaseBoxType>(dstTy)) {
+        dstIsDesc = true;
+        computeType = srcTy;
+        seqTy = mlir::dyn_cast<fir::SequenceType>(srcTy);
+      }
+      fir::KindMapping kindMap{fir::getKindMapping(mod)};
+      int width = computeWidth(loc, computeType, kindMap);
+
+      mlir::Value nbElement;
+      mlir::Type idxTy = rewriter.getIndexType();
+      if (!op.getShape()) {
+        nbElement = rewriter.create<mlir::arith::ConstantOp>(
+            loc, idxTy,
+            rewriter.getIntegerAttr(idxTy, seqTy.getConstantArraySize()));
+      } else {
+        auto shapeOp =
+            mlir::dyn_cast<fir::ShapeOp>(op.getShape().getDefiningOp());
+        nbElement =
+            createConvertOp(rewriter, loc, idxTy, shapeOp.getExtents()[0]);
+        for (unsigned i = 1; i < shapeOp.getExtents().size(); ++i) {
+          auto operand =
+              createConvertOp(rewriter, loc, idxTy, shapeOp.getExtents()[i]);
+          nbElement =
+              rewriter.create<mlir::arith::MulIOp>(loc, nbElement, operand);
+        }
+      }
+
+      mlir::Value widthValue = rewriter.create<mlir::arith::ConstantOp>(
+          loc, idxTy, rewriter.getIntegerAttr(idxTy, width));
+      mlir::Value bytes =
+          rewriter.create<mlir::arith::MulIOp>(loc, nbElement, widthValue);
+
+      mlir::func::FuncOp func =
+          dstIsDesc
+              ? fir::runtime::getRuntimeFunc<mkRTKey(CUFDataTransferDescPtr)>(
+                    loc, builder)
+              : fir::runtime::getRuntimeFunc<mkRTKey(CUFDataTransferPtrDesc)>(
+                    loc, builder);
+      auto fTy = func.getFunctionType();
+      mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
+      mlir::Value sourceLine =
+          fir::factory::locationToLineNo(builder, loc, fTy.getInput(5));
+      mlir::Value dst =
+          dstIsDesc ? builder.loadIfRef(loc, op.getDst()) : op.getDst();
+      mlir::Value src = mlir::isa<fir::BaseBoxType>(srcTy)
+                            ? builder.loadIfRef(loc, op.getSrc())
+                            : op.getSrc();
+      llvm::SmallVector<mlir::Value> args{
+          fir::runtime::createArguments(builder, loc, fTy, dst, src, bytes,
+                                        modeValue, sourceFile, sourceLine)};
+      builder.create<fir::CallOp>(loc, func, args);
+      rewriter.eraseOp(op);
+    }
+    return mlir::success();
+  }
+};
+
 class CufOpConversion : public fir::impl::CufOpConversionBase<CufOpConversion> {
 public:
   void runOnOperation() override {
@@ -285,10 +451,17 @@ class CufOpConversion : public fir::impl::CufOpConversionBase<CufOpConversion> {
         [](::cuf::AllocateOp op) { return needDoubleDescriptor(op); });
     target.addDynamicallyLegalOp<cuf::DeallocateOp>(
         [](::cuf::DeallocateOp op) { return needDoubleDescriptor(op); });
-    target.addLegalDialect<fir::FIROpsDialect>();
+    target.addDynamicallyLegalOp<cuf::DataTransferOp>(
+        [](::cuf::DataTransferOp op) {
+          mlir::Type srcTy = fir::unwrapRefType(op.getSrc().getType());
+          mlir::Type dstTy = fir::unwrapRefType(op.getDst().getType());
+          return !mlir::isa<fir::BaseBoxType>(srcTy) &&
+                 !mlir::isa<fir::BaseBoxType>(dstTy);
+        });
+    target.addLegalDialect<fir::FIROpsDialect, mlir::arith::ArithDialect>();
     patterns.insert<CufAllocOpConversion>(ctx, &*dl, &typeConverter);
     patterns.insert<CufAllocateOpConversion, CufDeallocateOpConversion,
-                    CufFreeOpConversion>(ctx);
+                    CufFreeOpConversion, CufDataTransferOpConversion>(ctx);
     if (mlir::failed(mlir::applyPartialConversion(getOperation(), target,
                                                   std::move(patterns)))) {
       mlir::emitError(mlir::UnknownLoc::get(ctx),
diff --git a/flang/test/Fir/CUDA/cuda-data-transfer.fir b/flang/test/Fir/CUDA/cuda-data-transfer.fir
new file mode 100644
index 00000000000000..f639a6c22b76d0
--- /dev/null
+++ b/flang/test/Fir/CUDA/cuda-data-transfer.fir
@@ -0,0 +1,140 @@
+// RUN: fir-opt --cuf-convert %s | FileCheck %s
+
+module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<f80, dense<128> : vector<2xi64>>, #dlti.dl_entry<i128, dense<128> : vector<2xi64>>, #dlti.dl_entry<i64, dense<64> : vector<2xi64>>, #dlti.dl_entry<!llvm.ptr<272>, dense<64> : vector<4xi64>>, #dlti.dl_entry<!llvm.ptr<271>, dense<32> : vector<4xi64>>, #dlti.dl_entry<!llvm.ptr<270>, dense<32> : vector<4xi64>>, #dlti.dl_entry<f128, dense<128> : vector<2xi64>>, #dlti.dl_entry<f64, dense<64> : vector<2xi64>>, #dlti.dl_entry<f16, dense<16> : vector<2xi64>>, #dlti.dl_entry<i32, dense<32> : vector<2xi64>>, #dlti.dl_entry<i16, dense<16> : vector<2xi64>>, #dlti.dl_entry<i8, dense<8> : vector<2xi64>>, #dlti.dl_entry<i1, dense<8> : vector<2xi64>>, #dlti.dl_entry<!llvm.ptr, dense<64> : vector<4xi64>>, #dlti.dl_entry<"dlti.endianness", "little">, #dlti.dl_entry<"dlti.stack_alignment", 128 : i64>>} {
+
+func.func @_QPsub1() {
+  %0 = cuf.alloc !fir.box<!fir.heap<!fir.array<?xi32>>> {bindc_name = "adev", data_attr = #cuf.cuda<device>, uniq_name = "_QFsub1Eadev"} -> !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
+  %4:2 = hlfir.declare %0 {data_attr = #cuf.cuda<device>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFsub1Eadev"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>)
+  %5 = fir.alloca !fir.box<!fir.heap<!fir.array<?xi32>>> {bindc_name = "ahost", uniq_name = "_QFsub1Eahost"}
+  %9:2 = hlfir.declare %5 {fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFsub1Eahost"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>)
+  cuf.data_transfer %4#0 to %9#0 {transfer_kind = #cuf.cuda_transfer<device_host>} : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
+  cuf.free %4#1 : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>> {data_attr = #cuf.cuda<device>}
+  return
+}
+
+// CHECK-LABEL: func.func @_QPsub1()
+// CHECK: %[[ADEV:.*]]:2 = hlfir.declare %{{.*}} {data_attr = #cuf.cuda<device>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFsub1Eadev"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>)
+// CHECK: %[[AHOST:.*]]:2 = hlfir.declare %{{.*}} {fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFsub1Eahost"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>)
+// CHECK: %[[AHOST_LOAD:.*]] = fir.load %[[AHOST]]#0 : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
+// CHECK: %[[ADEV_LOAD:.*]] = fir.load %[[ADEV]]#0 : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
+// CHECK: %[[AHOST_BOX:.*]] = fir.convert %[[AHOST_LOAD]] : (!fir.box<!fir.heap<!fir.array<?xi32>>>) -> !fir.box<none>
+// CHECK: %[[ADEV_BOX:.*]] = fir.convert %[[ADEV_LOAD]] : (!fir.box<!fir.heap<!fir.array<?xi32>>>) -> !fir.box<none>
+// CHECK: fir.call @_FortranACUFDataTransferDescDesc(%[[AHOST_BOX]], %[[ADEV_BOX]], %c1{{.*}}, %{{.*}}, %{{.*}}) : (!fir.box<none>, !fir.box<none>, i32, !fir.ref<i8>, i32) -> none
+
+func.func @_QPsub2() {
+  %0 = cuf.alloc !fir.box<!fir.heap<!fir.array<?xi32>>> {bindc_name = "adev", data_attr = #cuf.cuda<device>, uniq_name = "_QFsub2Eadev"} -> !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
+  %4:2 = hlfir.declare %0 {data_attr = #cuf.cuda<device>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFsub2Eadev"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>)
+  %c2_i32 = arith.constant 2 : i32
+  cuf.data_transfer %c2_i32 to %4#0 {transfer_kind = #cuf.cuda_transfer<host_device>} : i32, !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
+  cuf.free %4#1 : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>> {data_attr = #cuf.cuda<device>}
+  return
+}
+
+// CHECK-LABEL: func.func @_QPsub2()
+// CHECK: %[[TEMP:.*]] = fir.alloca i32
+// CHECK: %[[ADEV:.*]]:2 = hlfir.declare %{{.*}} {data_attr = #cuf.cuda<device>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFsub2Eadev"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>)
+// CHECK: %[[C2:.*]] = arith.constant 2 : i32
+// CHECK: fir.store %[[C2]] to %[[TEMP]] : !fir.ref<i32>
+// CHECK: %[[ADEV_LOAD:.*]] = fir.load %[[ADEV]]#0 : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
+// CHECK: %[[ADEV_BOX:.*]] = fir.convert %[[ADEV_LOAD]] : (!fir.box<!fir.heap<!fir.array<?xi32>>>) -> !fir.box<none>
+// CHECK: %[[TEMP_CONV:.*]] = fir.convert %[[TEMP]] : (!fir.ref<i32>) -> !fir.llvm_ptr<i8>
+// CHECK: fir.call @_FortranACUFMemsetDescriptor(%[[ADEV_BOX]], %[[TEMP_CONV]], %{{.*}}, %{{.*}}) : (!fir.box<none>, !fir.llvm_ptr<i8>, !fir.ref<i8>, i32) -> none
+
+func.func @_QPsub3() {
+  %0 = cuf.alloc !fir.box<!fir.heap<!fir.array<?xi32>>> {bindc_name = "adev", data_attr = #cuf.cuda<device>, uniq_name = "_QFsub3Eadev"} -> !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
+  %4:2 = hlfir.declare %0 {data_attr = #cuf.cuda<device>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFsub3Eadev"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>)
+  %5 = fir.alloca i32 {bindc_name = "v", uniq_name = "_QFsub3Ev"}
+  %6:2 = hlfir.declare %5 {uniq_name = "_QFsub3Ev"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+  cuf.data_transfer %6#0 to %4#0 {transfer_kind = #cuf.cuda_transfer<host_device>} : !fir.ref<i32>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
+  cuf.free %4#1 : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>> {data_attr = #cuf.cuda<device>}
+  return
+}
+
+// CHECK-LABEL: func.func @_QPsub3()
+// CHECK: %[[ADEV:.*]]:2 = hlfir.declare %{{.*}} {data_attr = #cuf.cuda<device>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFsub3Eadev"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>)
+// CHECK: %[[V:.*]]:2 = hlfir.declare %{{.*}} {uniq_name = "_QFsub3Ev"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+// CHECK: %[[ADEV_LOAD:.*]] = fir.load %[[ADEV]]#0 : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
+// CHECK: %[[ADEV_BOX:.*]] = fir.convert %[[ADEV_LOAD]] : (!fir.box<!fir.heap<!fir.array<?xi32>>>) -> !fir.box<none>
+// CHECK: %[[V_CONV:.*]] = fir.convert %[[V]]#0 : (!fir.ref<i32>) -> !fir.llvm_ptr<i8>
+// CHECK: fir.call @_FortranACUFMemsetDescriptor(%[[ADEV_BOX]], %[[V_CONV]], %{{.*}}, %{{.*}}) : (!fir.box<none>, !fir.llvm_ptr<i8>, !fir.ref<i8>, i32) -> none
+
+func.func @_QPsub4() {
+  %0 = cuf.alloc !fir.box<!fir.heap<!fir.array<?xi32>>> {bindc_name = "adev", data_attr = #cuf.cuda<device>, uniq_name = "_QFsub4Eadev"} -> !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
+  %4:2 = hlfir.declare %0 {data_attr = #cuf.cuda<device>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFsub4Eadev"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>)
+  %c10 = arith.constant 10 : index
+  %5 = fir.alloca !fir.array<10xi32> {bindc_name = "ahost", uniq_name = "_QFsub4Eahost"}
+  %6 = fir.shape %c10 : (index) -> !fir.shape<1>
+  %7:2 = hlfir.declare %5(%6) {uniq_name = "_QFsub4Eahost"} : (!fir.ref<!fir.array<10xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<10xi32>>, !fir.ref<!fir.array<10xi32>>)
+  cuf.data_transfer %7#0 to %4#0 {transfer_kind = #cuf.cuda_transfer<host_device>} : !fir.ref<!fir.array<10xi32>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
+  cuf.data_transfer %4#0 to %7#0 {transfer_kind = #cuf.cuda_transfer<device_host>} : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, !fir.ref<!fir.array<10xi32>>
+  cuf.free %4#1 : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>> {data_attr = #cuf.cuda<device>}
+  return
+}
+
+// CHECK-LABEL: func.func @_QPsub4()
+// CHECK: %[[ADEV:.*]]:2 = hlfir.declare %{{.*}} {data_attr = #cuf.cuda<device>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFsub4Eadev"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>)
+// CHECK: %[[AHOST:.*]]:2 = hlfir.declare %{{.*}}(%{{.*}}) {uniq_name = "_QFsub4Eahost"} : (!fir.ref<!fir.array<10xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<10xi32>>, !fir.ref<!fir.array<10xi32>>)
+// CHECK: %[[NBELEM:.*]] = arith.constant 10 : index
+// CHECK: %[[WIDTH:.*]] = arith.constant 4 : index
+// CHECK: %[[BYTES:.*]] = arith.muli %[[NBELEM]], %[[WIDTH]] : index
+// CHECK: %[[ADEV_LOAD:.*]] = fir.load %[[ADEV]]#0 : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
+// CHECK: %[[ADEV_BOX:.*]] = fir.convert %[[ADEV_LOAD]] : (!fir.box<!fir.heap<!fir.array<?xi32>>>) -> !fir.box<none>
+// CHECK: %[[AHOST_PTR:.*]] = fir.convert %[[AHOST]]#0 : (!fir.ref<!fir.array<10xi32>>) -> !fir.llvm_ptr<i8>
+// CHECK: %[[BYTES_CONV:.*]] = fir.convert %[[BYTES]] : (index) -> i64
+// CHECK: fir.call @_FortranACUFDataTransferDescPtr(%[[ADEV_BOX]], %[[AHOST_PTR]], %[[BYTES_CONV]], %c0{{.*}}, %{{.*}}, %{{.*}}) : (!fir.box<none>, !fir.llvm_ptr<i8>, i64, i32, !fir.ref<i8>, i32) -> none
+// CHECK: %[[NBELEM:.*]] = arith.constant 10 : index
+// CHECK: %[[WIDTH:.*]] = arith.constant 4 : index
+// CHECK: %[[BYTES:.*]] = arith.muli %[[NBELEM]], %[[WIDTH]] : index
+// CHECK: %[[ADEV_LOAD:.*]] = fir.load %[[ADEV]]#0 : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
+// CHECK: %[[AHOST_PTR:.*]] = fir.convert %[[AHOST]]#0 : (!fir.ref<!fir.array<10xi32>>) -> !fir.llvm_ptr<i8>
+// CHECK: %[[ADEV_BOX:.*]] = fir.convert %[[ADEV_LOAD]] : (!fir.box<!fir.heap<!fir.array<?xi32>>>) -> !fir.box<none>
+// CHECK: %[[BYTES_CONV:.*]] = fir.convert %[[BYTES]] : (index) -> i64
+// CHECK: fir.call @_FortranACUFDataTransferPtrDesc(%[[AHOST_PTR]], %[[ADEV_BOX]], %[[BYTES_CONV]], %c1{{.*}}, %{{.*}}, %{{.*}}) : (!fir.llvm_ptr<i8>, !fir.box<none>, i64, i32, !fir.ref<i8>, i32) -> none
+
+func.func @_QPsub5(%arg0: !fir.ref<i32> {fir.bindc_name = "n"}) {
+  %0 = fir.dummy_scope : !fir.dscope
+  %1 = cuf.alloc !fir.box<!fir.heap<!fir.array<?x?xi32>>> {bindc_name = "adev", data_attr = #cuf.cuda<device>, uniq_name = "_QFsub5Eadev"} -> !fir.ref<!fir.box<!fir.heap<!fir.array<?x?xi32>>>>
+  %5:2 = hlfir.declare %1 {data_attr = #cuf.cuda<device>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFsub5Eadev"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?x?xi32>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?x?xi32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?x?xi32>>>>)
+  %6:2 = hlfir.declare %arg0 dummy_scope %0 {uniq_name = "_...
[truncated]
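
For the descriptor-to-pointer and pointer-to-descriptor cases above, the pattern sizes the transfer as element count times element width (see `computeWidth`), where the count comes either from a constant array extent or from the `fir.shape` operands. The stand-alone sketch below is not taken from the patch; it only illustrates that arithmetic for the `!fir.array<10xi32>` host array in `@_QPsub4`.

```cpp
// Hypothetical, self-contained illustration of the byte-count computation the
// pattern emits as arith ops: bytes = number of elements * element width.
#include <cstddef>
#include <iostream>

// Mirrors computeWidth's idea for an integer element type: width in bytes is
// the bit width divided by 8.
static std::size_t widthInBytes(unsigned bitWidth) { return bitWidth / 8; }

int main() {
  std::size_t nbElements = 10;                   // constant extent from the type
  std::size_t bytes = nbElements * widthInBytes(32);
  std::cout << bytes << " bytes\n";              // prints 40
  return 0;
}
```

The descriptor-to-descriptor entry point takes no byte count; presumably the runtime derives the extents from the descriptors themselves.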

@clementval clementval merged commit 0bbebf6 into llvm:main Sep 17, 2024
11 checks passed
@clementval clementval deleted the cuf_data_transfer_conversion_desc branch September 17, 2024 18:00
hamphet pushed a commit to hamphet/llvm-project that referenced this pull request Sep 18, 2024
Convert cuf.data_transfer operations involving descriptors to the newly
introduced entry points (llvm#108244).
tmsri pushed a commit to tmsri/llvm-project that referenced this pull request Sep 19, 2024
Convert cuf.data_transfer operations involving descriptors to the newly
introduced entry points (llvm#108244).