Skip to content

[flang][cuda] Pass descriptor by reference for CUFMemsetDescriptor #114338

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Oct 31, 2024

Conversation

clementval
Copy link
Contributor

Similar to #114288

@llvmbot llvmbot added flang:runtime flang Flang issues not falling into any other category flang:fir-hlfir labels Oct 31, 2024
@llvmbot
Copy link
Member

llvmbot commented Oct 31, 2024

@llvm/pr-subscribers-flang-fir-hlfir

@llvm/pr-subscribers-flang-runtime

Author: Valentin Clement (バレンタイン クレメン) (clementval)

Changes

Similar to #114288


Full diff: https://github.com/llvm/llvm-project/pull/114338.diff

4 Files Affected:

  • (modified) flang/include/flang/Runtime/CUDA/memory.h (+1-1)
  • (modified) flang/lib/Optimizer/Transforms/CUFOpConversion.cpp (+1-2)
  • (modified) flang/runtime/CUDA/memory.cpp (+2-2)
  • (modified) flang/test/Fir/CUDA/cuda-data-transfer.fir (+4-6)
diff --git a/flang/include/flang/Runtime/CUDA/memory.h b/flang/include/flang/Runtime/CUDA/memory.h
index fb48152d707182..6d2e0c0f15942b 100644
--- a/flang/include/flang/Runtime/CUDA/memory.h
+++ b/flang/include/flang/Runtime/CUDA/memory.h
@@ -28,7 +28,7 @@ void RTDECL(CUFMemFree)(void *devicePtr, unsigned type,
 /// Set value to the data hold by a descriptor. The \p value pointer must be
 /// addressable to the same amount of bytes specified by the element size of
 /// the descriptor \p desc.
-void RTDECL(CUFMemsetDescriptor)(const Descriptor &desc, void *value,
+void RTDECL(CUFMemsetDescriptor)(Descriptor *desc, void *value,
     const char *sourceFile = nullptr, int sourceLine = 0);
 
 /// Data transfer from a pointer to a pointer.
diff --git a/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp b/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
index e3e441360e949b..4050064ebe95d8 100644
--- a/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
+++ b/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
@@ -552,9 +552,8 @@ struct CUFDataTransferOpConversion
       mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
       mlir::Value sourceLine =
           fir::factory::locationToLineNo(builder, loc, fTy.getInput(3));
-      mlir::Value dst = builder.loadIfRef(loc, op.getDst());
       llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
-          builder, loc, fTy, dst, val, sourceFile, sourceLine)};
+          builder, loc, fTy, op.getDst(), val, sourceFile, sourceLine)};
       builder.create<fir::CallOp>(loc, func, args);
       rewriter.eraseOp(op);
     } else {
diff --git a/flang/runtime/CUDA/memory.cpp b/flang/runtime/CUDA/memory.cpp
index 4778a4ae77683f..d03f1cc0e48d66 100644
--- a/flang/runtime/CUDA/memory.cpp
+++ b/flang/runtime/CUDA/memory.cpp
@@ -49,8 +49,8 @@ void RTDEF(CUFMemFree)(
   }
 }
 
-void RTDEF(CUFMemsetDescriptor)(const Descriptor &desc, void *value,
-    const char *sourceFile, int sourceLine) {
+void RTDEF(CUFMemsetDescriptor)(
+    Descriptor *desc, void *value, const char *sourceFile, int sourceLine) {
   Terminator terminator{sourceFile, sourceLine};
   terminator.Crash("not yet implemented: CUDA data transfer from a scalar "
                    "value to a descriptor");
diff --git a/flang/test/Fir/CUDA/cuda-data-transfer.fir b/flang/test/Fir/CUDA/cuda-data-transfer.fir
index b99e09fb76468b..cee3048e279cc7 100644
--- a/flang/test/Fir/CUDA/cuda-data-transfer.fir
+++ b/flang/test/Fir/CUDA/cuda-data-transfer.fir
@@ -33,10 +33,9 @@ func.func @_QPsub2() {
 // CHECK: %[[ADEV:.*]]:2 = hlfir.declare %{{.*}} {data_attr = #cuf.cuda<device>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFsub2Eadev"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>)
 // CHECK: %[[C2:.*]] = arith.constant 2 : i32
 // CHECK: fir.store %[[C2]] to %[[TEMP]] : !fir.ref<i32>
-// CHECK: %[[ADEV_LOAD:.*]] = fir.load %[[ADEV]]#0 : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
-// CHECK: %[[ADEV_BOX:.*]] = fir.convert %[[ADEV_LOAD]] : (!fir.box<!fir.heap<!fir.array<?xi32>>>) -> !fir.box<none>
+// CHECK: %[[ADEV_BOX:.*]] = fir.convert %[[ADEV]]#0 : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> !fir.ref<!fir.box<none>>
 // CHECK: %[[TEMP_CONV:.*]] = fir.convert %[[TEMP]] : (!fir.ref<i32>) -> !fir.llvm_ptr<i8>
-// CHECK: fir.call @_FortranACUFMemsetDescriptor(%[[ADEV_BOX]], %[[TEMP_CONV]], %{{.*}}, %{{.*}}) : (!fir.box<none>, !fir.llvm_ptr<i8>, !fir.ref<i8>, i32) -> none
+// CHECK: fir.call @_FortranACUFMemsetDescriptor(%[[ADEV_BOX]], %[[TEMP_CONV]], %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, !fir.llvm_ptr<i8>, !fir.ref<i8>, i32) -> none
 
 func.func @_QPsub3() {
   %0 = cuf.alloc !fir.box<!fir.heap<!fir.array<?xi32>>> {bindc_name = "adev", data_attr = #cuf.cuda<device>, uniq_name = "_QFsub3Eadev"} -> !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
@@ -51,10 +50,9 @@ func.func @_QPsub3() {
 // CHECK-LABEL: func.func @_QPsub3()
 // CHECK: %[[ADEV:.*]]:2 = hlfir.declare %{{.*}} {data_attr = #cuf.cuda<device>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFsub3Eadev"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>)
 // CHECK: %[[V:.*]]:2 = hlfir.declare %{{.*}} {uniq_name = "_QFsub3Ev"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
-// CHECK: %[[ADEV_LOAD:.*]] = fir.load %[[ADEV]]#0 : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
-// CHECK: %[[ADEV_BOX:.*]] = fir.convert %[[ADEV_LOAD]] : (!fir.box<!fir.heap<!fir.array<?xi32>>>) -> !fir.box<none>
+// CHECK: %[[ADEV_BOX:.*]] = fir.convert %[[ADEV]]#0 : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> !fir.ref<!fir.box<none>>
 // CHECK: %[[V_CONV:.*]] = fir.convert %[[V]]#0 : (!fir.ref<i32>) -> !fir.llvm_ptr<i8>
-// CHECK: fir.call @_FortranACUFMemsetDescriptor(%[[ADEV_BOX]], %[[V_CONV]], %{{.*}}, %{{.*}}) : (!fir.box<none>, !fir.llvm_ptr<i8>, !fir.ref<i8>, i32) -> none
+// CHECK: fir.call @_FortranACUFMemsetDescriptor(%[[ADEV_BOX]], %[[V_CONV]], %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, !fir.llvm_ptr<i8>, !fir.ref<i8>, i32) -> none
 
 func.func @_QPsub4() {
   %0 = cuf.alloc !fir.box<!fir.heap<!fir.array<?xi32>>> {bindc_name = "adev", data_attr = #cuf.cuda<device>, uniq_name = "_QFsub4Eadev"} -> !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>

@clementval clementval merged commit e4e9fea into llvm:main Oct 31, 2024
12 checks passed
smallp-o-p pushed a commit to smallp-o-p/llvm-project that referenced this pull request Nov 3, 2024
NoumanAmir657 pushed a commit to NoumanAmir657/llvm-project that referenced this pull request Nov 4, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
flang:fir-hlfir flang:runtime flang Flang issues not falling into any other category
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants