Skip to content

[flang][cuda] Avoid assign element mismatch when doing data transfer from a constant #128252

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 22, 2025

Conversation

clementval
Copy link
Contributor

Currently when we do a CUDA data transfer from a constant, we embox it and delegate the assignment to the runtime. When the type of the constant is not exactly the same as the destination descriptor, the runtime will emit an assignment mismatch error.

Convert the constant when necessary so the assignment is fine.

@clementval clementval requested a review from wangzpgi February 21, 2025 23:55
@llvmbot llvmbot added flang Flang issues not falling into any other category flang:fir-hlfir labels Feb 21, 2025
@llvmbot
Copy link
Member

llvmbot commented Feb 21, 2025

@llvm/pr-subscribers-flang-fir-hlfir

Author: Valentin Clement (バレンタイン クレメン) (clementval)

Changes

Currently when we do a CUDA data transfer from a constant, we embox it and delegate the assignment to the runtime. When the type of the constant is not exactly the same as the destination descriptor, the runtime will emit an assignment mismatch error.

Convert the constant when necessary so the assignment is fine.


Full diff: https://github.com/llvm/llvm-project/pull/128252.diff

2 Files Affected:

  • (modified) flang/lib/Optimizer/Transforms/CUFOpConversion.cpp (+19-7)
  • (modified) flang/test/Fir/CUDA/cuda-data-transfer.fir (+22)
diff --git a/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp b/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
index 1f0576aa82f83..2ab2d84f1643d 100644
--- a/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
+++ b/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
@@ -541,7 +541,8 @@ static mlir::Value getShapeFromDecl(mlir::Value src) {
 
 static mlir::Value emboxSrc(mlir::PatternRewriter &rewriter,
                             cuf::DataTransferOp op,
-                            const mlir::SymbolTable &symtab) {
+                            const mlir::SymbolTable &symtab,
+                            mlir::Type dstEleTy = nullptr) {
   auto mod = op->getParentOfType<mlir::ModuleOp>();
   mlir::Location loc = op.getLoc();
   fir::FirOpBuilder builder(rewriter, mod);
@@ -555,11 +556,21 @@ static mlir::Value emboxSrc(mlir::PatternRewriter &rewriter,
       // from a LOGICAL constant. Store it as a fir.logical.
       srcTy = fir::LogicalType::get(rewriter.getContext(), 4);
       src = createConvertOp(rewriter, loc, srcTy, src);
+      addr = builder.createTemporary(loc, srcTy);
+      builder.create<fir::StoreOp>(loc, src, addr);
+    } else {
+      if (dstEleTy && fir::isa_trivial(dstEleTy) && srcTy != dstEleTy) {
+        // Use dstEleTy and convert to avoid assign mismatch.
+        addr = builder.createTemporary(loc, dstEleTy);
+        auto conv = builder.create<fir::ConvertOp>(loc, dstEleTy, src);
+        builder.create<fir::StoreOp>(loc, conv, addr);
+        srcTy = dstEleTy;
+      } else {
+        // Put constant in memory if it is not.
+        addr = builder.createTemporary(loc, srcTy);
+        builder.create<fir::StoreOp>(loc, src, addr);
+      }
     }
-    // Put constant in memory if it is not.
-    mlir::Value alloc = builder.createTemporary(loc, srcTy);
-    builder.create<fir::StoreOp>(loc, src, alloc);
-    addr = alloc;
   } else {
     addr = op.getSrc();
   }
@@ -729,7 +740,7 @@ struct CUFDataTransferOpConversion
     };
 
     // Conversion of data transfer involving at least one descriptor.
-    if (mlir::isa<fir::BaseBoxType>(dstTy)) {
+    if (auto dstBoxTy = mlir::dyn_cast<fir::BaseBoxType>(dstTy)) {
       // Transfer to a descriptor.
       mlir::func::FuncOp func =
           isDstGlobal(op)
@@ -740,7 +751,8 @@ struct CUFDataTransferOpConversion
       mlir::Value dst = op.getDst();
       mlir::Value src = op.getSrc();
       if (!mlir::isa<fir::BaseBoxType>(srcTy)) {
-        src = emboxSrc(rewriter, op, symtab);
+        mlir::Type dstEleTy = fir::unwrapInnerType(dstBoxTy.getEleTy());
+        src = emboxSrc(rewriter, op, symtab, dstEleTy);
         if (fir::isa_trivial(srcTy))
           func = fir::runtime::getRuntimeFunc<mkRTKey(CUFDataTransferCstDesc)>(
               loc, builder);
diff --git a/flang/test/Fir/CUDA/cuda-data-transfer.fir b/flang/test/Fir/CUDA/cuda-data-transfer.fir
index b62c500f4a2d3..a724d9f681fb6 100644
--- a/flang/test/Fir/CUDA/cuda-data-transfer.fir
+++ b/flang/test/Fir/CUDA/cuda-data-transfer.fir
@@ -582,4 +582,26 @@ func.func @_QPchecksums(%arg0: !fir.box<!fir.array<?xf64>> {cuf.data_attr = #cuf
 // CHECK: %[[SRC:.*]] = fir.convert %{{.*}} : (!fir.ref<!fir.box<!fir.array<?xf64>>>) -> !fir.ref<!fir.box<none>>
 // CHECK: fir.call @_FortranACUFDataTransferDescDescNoRealloc(%[[DST]], %[[SRC]], %{{.*}}, %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, !fir.ref<!fir.box<none>>, i32, !fir.ref<i8>, i32) -> ()
 
+func.func @_QPsub20() {
+  %0 = cuf.alloc !fir.box<!fir.heap<f32>> {bindc_name = "r", data_attr = #cuf.cuda<device>, uniq_name = "_QFsub20Er"} -> !fir.ref<!fir.box<!fir.heap<f32>>>
+  %1 = fir.zero_bits !fir.heap<f32>
+  %2 = fir.embox %1 {allocator_idx = 2 : i32} : (!fir.heap<f32>) -> !fir.box<!fir.heap<f32>>
+  fir.store %2 to %0 : !fir.ref<!fir.box<!fir.heap<f32>>>
+  %3:2 = hlfir.declare %0 {data_attr = #cuf.cuda<device>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFsub20Er"} : (!fir.ref<!fir.box<!fir.heap<f32>>>) -> (!fir.ref<!fir.box<!fir.heap<f32>>>, !fir.ref<!fir.box<!fir.heap<f32>>>)
+  %c0_i32 = arith.constant 0 : i32
+  cuf.data_transfer %c0_i32 to %3#0 {transfer_kind = #cuf.cuda_transfer<host_device>} : i32, !fir.ref<!fir.box<!fir.heap<f32>>>
+  return
+}
+
+// CHECK-LABEL:func.func @_QPsub20
+// CHECK: %[[BOX_ALLOCA:.*]] = fir.alloca !fir.box<f32>
+// CHECK: %[[TMP:.*]] = fir.alloca f32
+// CHECK: %[[CONV:.*]] = fir.convert %c0{{.*}} : (i32) -> f32
+// CHECK: fir.store %[[CONV]] to %[[TMP]] : !fir.ref<f32>
+// CHECK: %[[BOX:.*]] = fir.embox %[[TMP]] : (!fir.ref<f32>) -> !fir.box<f32>
+// CHECK: fir.store %[[BOX]] to %[[BOX_ALLOCA]] : !fir.ref<!fir.box<f32>>
+// CHECK: %[[BOX_NONE:.*]] = fir.convert %[[BOX_ALLOCA]] : (!fir.ref<!fir.box<f32>>) -> !fir.ref<!fir.box<none>>
+// CHECK: fir.call @_FortranACUFDataTransferCstDesc(%13, %[[BOX_NONE]], %c0{{.*}}, %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, !fir.ref<!fir.box<none>>, i32, !fir.ref<i8>, i32) -> ()
+
 } // end of module
+

@clementval clementval merged commit 93b2e47 into llvm:main Feb 22, 2025
14 checks passed
@clementval clementval deleted the cuf_cst_data_transfer branch February 22, 2025 01:46
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
flang:fir-hlfir flang Flang issues not falling into any other category
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants