Skip to content

Passing descriptors by reference to CUDA runtime calls #114288

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Oct 30, 2024

Conversation

Renaud-K
Copy link
Contributor

Passing a descriptor as a const Descriptor & or a const Descriptor * generates a FIR signature where the box is passed by value.
This is an issue, as it requires a load of the box to be passed. But since, ultimately, all boxes are passed by reference a temporary is generated in LLVM and the reference to the temporary is passed.

The boxes addresses are registered with the CUDA runtime but the temporaries are not, thus preventing the runtime to properly map a host side address to its device side counterpart.

To address this issue, this PR changes the signatures to the transfer functions to pass a descriptor as a Descriptor *, which will in turn generate a FIR signature with that takes a box reference as an argument.

@Renaud-K Renaud-K requested a review from clementval October 30, 2024 18:28
@llvmbot llvmbot added flang:runtime flang Flang issues not falling into any other category flang:fir-hlfir labels Oct 30, 2024
@llvmbot
Copy link
Member

llvmbot commented Oct 30, 2024

@llvm/pr-subscribers-flang-fir-hlfir

@llvm/pr-subscribers-flang-runtime

Author: Renaud Kauffmann (Renaud-K)

Changes

Passing a descriptor as a const Descriptor & or a const Descriptor * generates a FIR signature where the box is passed by value.
This is an issue, as it requires a load of the box to be passed. But since, ultimately, all boxes are passed by reference a temporary is generated in LLVM and the reference to the temporary is passed.

The boxes addresses are registered with the CUDA runtime but the temporaries are not, thus preventing the runtime to properly map a host side address to its device side counterpart.

To address this issue, this PR changes the signatures to the transfer functions to pass a descriptor as a Descriptor *, which will in turn generate a FIR signature with that takes a box reference as an argument.


Full diff: https://github.com/llvm/llvm-project/pull/114288.diff

4 Files Affected:

  • (modified) flang/include/flang/Runtime/CUDA/memory.h (+4-5)
  • (modified) flang/lib/Optimizer/Transforms/CUFOpConversion.cpp (+4-7)
  • (modified) flang/runtime/CUDA/memory.cpp (+4-5)
  • (modified) flang/test/Fir/CUDA/cuda-data-transfer.fir (+11-17)
diff --git a/flang/include/flang/Runtime/CUDA/memory.h b/flang/include/flang/Runtime/CUDA/memory.h
index 3c3ae73d4ad7a1..fb48152d707182 100644
--- a/flang/include/flang/Runtime/CUDA/memory.h
+++ b/flang/include/flang/Runtime/CUDA/memory.h
@@ -36,19 +36,18 @@ void RTDECL(CUFDataTransferPtrPtr)(void *dst, void *src, std::size_t bytes,
     unsigned mode, const char *sourceFile = nullptr, int sourceLine = 0);
 
 /// Data transfer from a pointer to a descriptor.
-void RTDECL(CUFDataTransferDescPtr)(const Descriptor &dst, void *src,
+void RTDECL(CUFDataTransferDescPtr)(Descriptor *dst, void *src,
     std::size_t bytes, unsigned mode, const char *sourceFile = nullptr,
     int sourceLine = 0);
 
 /// Data transfer from a descriptor to a pointer.
-void RTDECL(CUFDataTransferPtrDesc)(void *dst, const Descriptor &src,
+void RTDECL(CUFDataTransferPtrDesc)(void *dst, Descriptor *src,
     std::size_t bytes, unsigned mode, const char *sourceFile = nullptr,
     int sourceLine = 0);
 
 /// Data transfer from a descriptor to a descriptor.
-void RTDECL(CUFDataTransferDescDesc)(const Descriptor &dst,
-    const Descriptor &src, unsigned mode, const char *sourceFile = nullptr,
-    int sourceLine = 0);
+void RTDECL(CUFDataTransferDescDesc)(Descriptor *dst, Descriptor *src,
+    unsigned mode, const char *sourceFile = nullptr, int sourceLine = 0);
 
 } // extern "C"
 } // namespace Fortran::runtime::cuda
diff --git a/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp b/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
index f1f3a95b220df5..e3e441360e949b 100644
--- a/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
+++ b/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
@@ -529,8 +529,8 @@ struct CUFDataTransferOpConversion
       mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
       mlir::Value sourceLine =
           fir::factory::locationToLineNo(builder, loc, fTy.getInput(4));
-      mlir::Value dst = builder.loadIfRef(loc, op.getDst());
-      mlir::Value src = builder.loadIfRef(loc, op.getSrc());
+      mlir::Value dst = op.getDst();
+      mlir::Value src = op.getSrc();
       llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
           builder, loc, fTy, dst, src, modeValue, sourceFile, sourceLine)};
       builder.create<fir::CallOp>(loc, func, args);
@@ -603,11 +603,8 @@ struct CUFDataTransferOpConversion
       mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
       mlir::Value sourceLine =
           fir::factory::locationToLineNo(builder, loc, fTy.getInput(5));
-      mlir::Value dst =
-          dstIsDesc ? builder.loadIfRef(loc, op.getDst()) : op.getDst();
-      mlir::Value src = mlir::isa<fir::BaseBoxType>(srcTy)
-                            ? builder.loadIfRef(loc, op.getSrc())
-                            : op.getSrc();
+      mlir::Value dst = op.getDst();
+      mlir::Value src = op.getSrc();
       llvm::SmallVector<mlir::Value> args{
           fir::runtime::createArguments(builder, loc, fTy, dst, src, bytes,
                                         modeValue, sourceFile, sourceLine)};
diff --git a/flang/runtime/CUDA/memory.cpp b/flang/runtime/CUDA/memory.cpp
index fc48b4343eea9d..4778a4ae77683f 100644
--- a/flang/runtime/CUDA/memory.cpp
+++ b/flang/runtime/CUDA/memory.cpp
@@ -73,23 +73,22 @@ void RTDEF(CUFDataTransferPtrPtr)(void *dst, void *src, std::size_t bytes,
   CUDA_REPORT_IF_ERROR(cudaMemcpy(dst, src, bytes, kind));
 }
 
-void RTDEF(CUFDataTransferDescPtr)(const Descriptor &desc, void *addr,
+void RTDEF(CUFDataTransferDescPtr)(Descriptor *desc, void *addr,
     std::size_t bytes, unsigned mode, const char *sourceFile, int sourceLine) {
   Terminator terminator{sourceFile, sourceLine};
   terminator.Crash(
       "not yet implemented: CUDA data transfer from a pointer to a descriptor");
 }
 
-void RTDEF(CUFDataTransferPtrDesc)(void *addr, const Descriptor &desc,
+void RTDEF(CUFDataTransferPtrDesc)(void *addr, Descriptor *desc,
     std::size_t bytes, unsigned mode, const char *sourceFile, int sourceLine) {
   Terminator terminator{sourceFile, sourceLine};
   terminator.Crash(
       "not yet implemented: CUDA data transfer from a descriptor to a pointer");
 }
 
-void RTDECL(CUFDataTransferDescDesc)(const Descriptor &dstDesc,
-    const Descriptor &srcDesc, unsigned mode, const char *sourceFile,
-    int sourceLine) {
+void RTDECL(CUFDataTransferDescDesc)(Descriptor *dstDesc, Descriptor *srcDesc,
+    unsigned mode, const char *sourceFile, int sourceLine) {
   Terminator terminator{sourceFile, sourceLine};
   terminator.Crash(
       "not yet implemented: CUDA data transfer between two descriptors");
diff --git a/flang/test/Fir/CUDA/cuda-data-transfer.fir b/flang/test/Fir/CUDA/cuda-data-transfer.fir
index c33c50115b9fc0..b99e09fb76468b 100644
--- a/flang/test/Fir/CUDA/cuda-data-transfer.fir
+++ b/flang/test/Fir/CUDA/cuda-data-transfer.fir
@@ -15,11 +15,9 @@ func.func @_QPsub1() {
 // CHECK-LABEL: func.func @_QPsub1()
 // CHECK: %[[ADEV:.*]]:2 = hlfir.declare %{{.*}} {data_attr = #cuf.cuda<device>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFsub1Eadev"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>)
 // CHECK: %[[AHOST:.*]]:2 = hlfir.declare %{{.*}} {fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFsub1Eahost"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>)
-// CHECK: %[[AHOST_LOAD:.*]] = fir.load %[[AHOST]]#0 : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
-// CHECK: %[[ADEV_LOAD:.*]] = fir.load %[[ADEV]]#0 : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
-// CHECK: %[[AHOST_BOX:.*]] = fir.convert %[[AHOST_LOAD]] : (!fir.box<!fir.heap<!fir.array<?xi32>>>) -> !fir.box<none>
-// CHECK: %[[ADEV_BOX:.*]] = fir.convert %[[ADEV_LOAD]] : (!fir.box<!fir.heap<!fir.array<?xi32>>>) -> !fir.box<none>
-// CHECK: fir.call @_FortranACUFDataTransferDescDesc(%[[AHOST_BOX]], %[[ADEV_BOX]], %c1{{.*}}, %{{.*}}, %{{.*}}) : (!fir.box<none>, !fir.box<none>, i32, !fir.ref<i8>, i32) -> none
+// CHECK: %[[AHOST_BOX:.*]] = fir.convert %[[AHOST]]#0 : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> !fir.ref<!fir.box<none>>
+// CHECK: %[[ADEV_BOX:.*]] = fir.convert %[[ADEV]]#0 : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> !fir.ref<!fir.box<none>>
+// CHECK: fir.call @_FortranACUFDataTransferDescDesc(%[[AHOST_BOX]], %[[ADEV_BOX]], %c1{{.*}}, %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, !fir.ref<!fir.box<none>>, i32, !fir.ref<i8>, i32) -> none
 
 func.func @_QPsub2() {
   %0 = cuf.alloc !fir.box<!fir.heap<!fir.array<?xi32>>> {bindc_name = "adev", data_attr = #cuf.cuda<device>, uniq_name = "_QFsub2Eadev"} -> !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
@@ -76,19 +74,17 @@ func.func @_QPsub4() {
 // CHECK: %[[NBELEM:.*]] = arith.constant 10 : index
 // CHECK: %[[WIDTH:.*]] = arith.constant 4 : index
 // CHECK: %[[BYTES:.*]] = arith.muli %[[NBELEM]], %[[WIDTH]] : index
-// CHECK: %[[ADEV_LOAD:.*]] = fir.load %[[ADEV]]#0 : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
-// CHECK: %[[ADEV_BOX:.*]] = fir.convert %[[ADEV_LOAD]] : (!fir.box<!fir.heap<!fir.array<?xi32>>>) -> !fir.box<none>
+// CHECK: %[[ADEV_BOX:.*]] = fir.convert %[[ADEV]]#0 : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> !fir.ref<!fir.box<none>>
 // CHECK: %[[AHOST_PTR:.*]] = fir.convert %[[AHOST]]#0 : (!fir.ref<!fir.array<10xi32>>) -> !fir.llvm_ptr<i8>
 // CHECK: %[[BYTES_CONV:.*]] = fir.convert %[[BYTES]] : (index) -> i64
-// CHECK: fir.call @_FortranACUFDataTransferDescPtr(%[[ADEV_BOX]], %[[AHOST_PTR]], %[[BYTES_CONV]], %c0{{.*}}, %{{.*}}, %{{.*}}) : (!fir.box<none>, !fir.llvm_ptr<i8>, i64, i32, !fir.ref<i8>, i32) -> none
+// CHECK: fir.call @_FortranACUFDataTransferDescPtr(%[[ADEV_BOX]], %[[AHOST_PTR]], %[[BYTES_CONV]], %c0{{.*}}, %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, !fir.llvm_ptr<i8>, i64, i32, !fir.ref<i8>, i32) -> none
 // CHECK: %[[NBELEM:.*]] = arith.constant 10 : index
 // CHECK: %[[WIDTH:.*]] = arith.constant 4 : index
 // CHECK: %[[BYTES:.*]] = arith.muli %[[NBELEM]], %[[WIDTH]] : index
-// CHECK: %[[ADEV_LOAD:.*]] = fir.load %[[ADEV]]#0 : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
 // CHECK: %[[AHOST_PTR:.*]] = fir.convert %[[AHOST]]#0 : (!fir.ref<!fir.array<10xi32>>) -> !fir.llvm_ptr<i8>
-// CHECK: %[[ADEV_BOX:.*]] = fir.convert %[[ADEV_LOAD]] : (!fir.box<!fir.heap<!fir.array<?xi32>>>) -> !fir.box<none>
+// CHECK: %[[ADEV_BOX:.*]] = fir.convert %[[ADEV]]#0 : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> !fir.ref<!fir.box<none>>
 // CHECK: %[[BYTES_CONV:.*]] = fir.convert %[[BYTES]] : (index) -> i64
-// CHECK: fir.call @_FortranACUFDataTransferPtrDesc(%[[AHOST_PTR]], %[[ADEV_BOX]], %[[BYTES_CONV]], %c1{{.*}}, %{{.*}}, %{{.*}}) : (!fir.llvm_ptr<i8>, !fir.box<none>, i64, i32, !fir.ref<i8>, i32) -> none
+// CHECK: fir.call @_FortranACUFDataTransferPtrDesc(%[[AHOST_PTR]], %[[ADEV_BOX]], %[[BYTES_CONV]], %c1{{.*}}, %{{.*}}, %{{.*}}) : (!fir.llvm_ptr<i8>, !fir.ref<!fir.box<none>>, i64, i32, !fir.ref<i8>, i32) -> none
 
 func.func @_QPsub5(%arg0: !fir.ref<i32> {fir.bindc_name = "n"}) {
   %0 = fir.dummy_scope : !fir.dscope
@@ -122,19 +118,17 @@ func.func @_QPsub5(%arg0: !fir.ref<i32> {fir.bindc_name = "n"}) {
 // CHECK: %[[NBELEM:.*]] = arith.muli %[[I1]], %[[I2]] : index
 // CHECK: %[[WIDTH:.*]] = arith.constant 4 : index
 // CHECK: %[[BYTES:.*]] = arith.muli %[[NBELEM]], %[[WIDTH]] : index
-// CHECK: %[[ADEV_LOAD:.*]] = fir.load %[[ADEV]]#0 : !fir.ref<!fir.box<!fir.heap<!fir.array<?x?xi32>>>>
-// CHECK: %[[ADEV_BOX:.*]] = fir.convert %[[ADEV_LOAD]] : (!fir.box<!fir.heap<!fir.array<?x?xi32>>>) -> !fir.box<none>
+// CHECK: %[[ADEV_BOX:.*]] = fir.convert %[[ADEV]]#0 : (!fir.ref<!fir.box<!fir.heap<!fir.array<?x?xi32>>>>) -> !fir.ref<!fir.box<none>>
 // CHECK: %[[AHOST_PTR:.*]] = fir.convert %[[AHOST]]#1 : (!fir.ref<!fir.array<?x?xi32>>) -> !fir.llvm_ptr<i8>
 // CHECK: %[[BYTES_CONV:.*]] = fir.convert %[[BYTES]] : (index) -> i64
-// CHECK: fir.call @_FortranACUFDataTransferDescPtr(%[[ADEV_BOX]], %[[AHOST_PTR]], %[[BYTES_CONV]], %c0{{.*}}, %{{.*}}, %{{.*}}) : (!fir.box<none>, !fir.llvm_ptr<i8>, i64, i32, !fir.ref<i8>, i32) -> none
+// CHECK: fir.call @_FortranACUFDataTransferDescPtr(%[[ADEV_BOX]], %[[AHOST_PTR]], %[[BYTES_CONV]], %c0{{.*}}, %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, !fir.llvm_ptr<i8>, i64, i32, !fir.ref<i8>, i32) -> none
 // CHECK: %[[NBELEM:.*]] = arith.muli %[[I1]], %[[I2]] : index
 // CHECK: %[[WIDTH:.*]] = arith.constant 4 : index
 // CHECK: %[[BYTES:.*]] = arith.muli %[[NBELEM]], %[[WIDTH]] : index
-// CHECK: %[[ADEV_LOAD:.*]] = fir.load %[[ADEV]]#0 : !fir.ref<!fir.box<!fir.heap<!fir.array<?x?xi32>>>>
 // CHECK: %[[AHOST_PTR:.*]] = fir.convert %[[AHOST]]#1 : (!fir.ref<!fir.array<?x?xi32>>) -> !fir.llvm_ptr<i8>
-// CHECK: %[[ADEV_BOX:.*]] = fir.convert %[[ADEV_LOAD]] : (!fir.box<!fir.heap<!fir.array<?x?xi32>>>) -> !fir.box<none>
+// CHECK: %[[ADEV_BOX:.*]] = fir.convert %[[ADEV]]#0 : (!fir.ref<!fir.box<!fir.heap<!fir.array<?x?xi32>>>>) -> !fir.ref<!fir.box<none>>
 // CHECK: %[[BYTES_CONV:.*]] = fir.convert %[[BYTES]] : (index) -> i64
-// CHECK: fir.call @_FortranACUFDataTransferPtrDesc(%[[AHOST_PTR]], %[[ADEV_BOX]], %[[BYTES_CONV]], %c1{{.*}}, %{{.*}}, %{{.*}}) : (!fir.llvm_ptr<i8>, !fir.box<none>, i64, i32, !fir.ref<i8>, i32) -> none
+// CHECK: fir.call @_FortranACUFDataTransferPtrDesc(%[[AHOST_PTR]], %[[ADEV_BOX]], %[[BYTES_CONV]], %c1{{.*}}, %{{.*}}, %{{.*}}) : (!fir.llvm_ptr<i8>, !fir.ref<!fir.box<none>>, i64, i32, !fir.ref<i8>, i32) -> none
 
 func.func @_QPsub6() {
   %0 = cuf.alloc i32 {bindc_name = "idev", data_attr = #cuf.cuda<device>, uniq_name = "_QFsub6Eidev"} -> !fir.ref<i32>

Copy link
Contributor

@clementval clementval left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM!

@Renaud-K Renaud-K merged commit bfe486f into llvm:main Oct 30, 2024
10 of 11 checks passed
smallp-o-p pushed a commit to smallp-o-p/llvm-project that referenced this pull request Nov 3, 2024
Passing a descriptor as a `const Descriptor &` or a `const Descriptor *`
generates a FIR signature where the box is passed by value.
This is an issue, as it requires a load of the box to be passed. But
since, ultimately, all boxes are passed by reference a temporary is
generated in LLVM and the reference to the temporary is passed.

The boxes addresses are registered with the CUDA runtime but the
temporaries are not, thus preventing the runtime to properly map a host
side address to its device side counterpart.

To address this issue, this PR changes the signatures to the transfer
functions to pass a descriptor as a `Descriptor *`, which will in turn
generate a FIR signature with that takes a box reference as an argument.
NoumanAmir657 pushed a commit to NoumanAmir657/llvm-project that referenced this pull request Nov 4, 2024
Passing a descriptor as a `const Descriptor &` or a `const Descriptor *`
generates a FIR signature where the box is passed by value.
This is an issue, as it requires a load of the box to be passed. But
since, ultimately, all boxes are passed by reference a temporary is
generated in LLVM and the reference to the temporary is passed.

The boxes addresses are registered with the CUDA runtime but the
temporaries are not, thus preventing the runtime to properly map a host
side address to its device side counterpart.

To address this issue, this PR changes the signatures to the transfer
functions to pass a descriptor as a `Descriptor *`, which will in turn
generate a FIR signature with that takes a box reference as an argument.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
flang:fir-hlfir flang:runtime flang Flang issues not falling into any other category
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants