Skip to content

[flang][cuda] Support scalar to array data transfer #115273

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Nov 7, 2024

Conversation

clementval
Copy link
Contributor

Do it via descriptor assignment until we have a more efficient way.

@llvmbot llvmbot added flang Flang issues not falling into any other category flang:fir-hlfir labels Nov 7, 2024
@llvmbot
Copy link
Member

llvmbot commented Nov 7, 2024

@llvm/pr-subscribers-flang-fir-hlfir

Author: Valentin Clement (バレンタイン クレメン) (clementval)

Changes

Do it via descriptor assignment until we have a more efficient way.


Full diff: https://github.com/llvm/llvm-project/pull/115273.diff

2 Files Affected:

  • (modified) flang/lib/Optimizer/Transforms/CUFOpConversion.cpp (+67-38)
  • (modified) flang/test/Fir/CUDA/cuda-data-transfer.fir (+14)
diff --git a/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp b/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
index 6187ca03d2c411..881f54133ce732 100644
--- a/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
+++ b/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
@@ -448,6 +448,53 @@ static mlir::Value getShapeFromDecl(mlir::Value src) {
   return mlir::Value{};
 }
 
+static mlir::Value emboxSrc(mlir::PatternRewriter &rewriter,
+                            cuf::DataTransferOp op,
+                            const mlir::SymbolTable &symtab) {
+  auto mod = op->getParentOfType<mlir::ModuleOp>();
+  mlir::Location loc = op.getLoc();
+  fir::FirOpBuilder builder(rewriter, mod);
+  mlir::Value addr;
+  mlir::Type srcTy = fir::unwrapRefType(op.getSrc().getType());
+  if (fir::isa_trivial(srcTy) &&
+      mlir::matchPattern(op.getSrc().getDefiningOp(), mlir::m_Constant())) {
+    // Put the constant in memory if it is not already there.
+    mlir::Value alloc = builder.createTemporary(loc, srcTy);
+    builder.create<fir::StoreOp>(loc, op.getSrc(), alloc);
+    addr = alloc;
+  } else {
+    addr = getDeviceAddress(rewriter, op.getSrcMutable(), symtab);
+  }
+  llvm::SmallVector<mlir::Value> lenParams;
+  mlir::Type boxTy = fir::BoxType::get(srcTy);
+  mlir::Value box =
+      builder.createBox(loc, boxTy, addr, getShapeFromDecl(op.getSrc()),
+                        /*slice=*/nullptr, lenParams,
+                        /*tdesc=*/nullptr);
+  mlir::Value src = builder.createTemporary(loc, box.getType());
+  builder.create<fir::StoreOp>(loc, box, src);
+  return src;
+}
+
+static mlir::Value emboxDst(mlir::PatternRewriter &rewriter,
+                            cuf::DataTransferOp op,
+                            const mlir::SymbolTable &symtab) {
+  auto mod = op->getParentOfType<mlir::ModuleOp>();
+  mlir::Location loc = op.getLoc();
+  fir::FirOpBuilder builder(rewriter, mod);
+  mlir::Type dstTy = fir::unwrapRefType(op.getDst().getType());
+  mlir::Value dstAddr = getDeviceAddress(rewriter, op.getDstMutable(), symtab);
+  mlir::Type dstBoxTy = fir::BoxType::get(dstTy);
+  llvm::SmallVector<mlir::Value> lenParams;
+  mlir::Value dstBox =
+      builder.createBox(loc, dstBoxTy, dstAddr, getShapeFromDecl(op.getDst()),
+                        /*slice=*/nullptr, lenParams,
+                        /*tdesc=*/nullptr);
+  mlir::Value dst = builder.createTemporary(loc, dstBox.getType());
+  builder.create<fir::StoreOp>(loc, dstBox, dst);
+  return dst;
+}
+
 struct CUFDataTransferOpConversion
     : public mlir::OpRewritePattern<cuf::DataTransferOp> {
   using OpRewritePattern::OpRewritePattern;
@@ -486,10 +533,22 @@ struct CUFDataTransferOpConversion
         !mlir::isa<fir::BaseBoxType>(dstTy)) {
 
       if (fir::isa_trivial(srcTy) && !fir::isa_trivial(dstTy)) {
-        // TODO: scalar to array data transfer.
-        mlir::emitError(loc,
-                        "not yet implemented: scalar to array data transfer\n");
-        return mlir::failure();
+        // Initialization of an array from a scalar value should be implemented
+        // via a kernel launch. Use the flang runtime via the Assign function
+        // until we have more infrastructure.
+        mlir::Value src = emboxSrc(rewriter, op, symtab);
+        mlir::Value dst = emboxDst(rewriter, op, symtab);
+        mlir::func::FuncOp func = fir::runtime::getRuntimeFunc<mkRTKey(
+            CUFDataTransferDescDescNoRealloc)>(loc, builder);
+        auto fTy = func.getFunctionType();
+        mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
+        mlir::Value sourceLine =
+            fir::factory::locationToLineNo(builder, loc, fTy.getInput(4));
+        llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
+            builder, loc, fTy, dst, src, modeValue, sourceFile, sourceLine)};
+        builder.create<fir::CallOp>(loc, func, args);
+        rewriter.eraseOp(op);
+        return mlir::success();
       }
 
       mlir::Type i64Ty = builder.getI64Type();
@@ -548,29 +607,8 @@ struct CUFDataTransferOpConversion
       mlir::Value dst = op.getDst();
       mlir::Value src = op.getSrc();
 
-      if (!mlir::isa<fir::BaseBoxType>(srcTy)) {
-        // If src is not a descriptor, create one.
-        mlir::Value addr;
-        if (fir::isa_trivial(srcTy) &&
-            mlir::matchPattern(op.getSrc().getDefiningOp(),
-                               mlir::m_Constant())) {
-          // Put constant in memory if it is not.
-          mlir::Value alloc = builder.createTemporary(loc, srcTy);
-          builder.create<fir::StoreOp>(loc, op.getSrc(), alloc);
-          addr = alloc;
-        } else {
-          addr = getDeviceAddress(rewriter, op.getSrcMutable(), symtab);
-        }
-        mlir::Type boxTy = fir::BoxType::get(srcTy);
-        llvm::SmallVector<mlir::Value> lenParams;
-        mlir::Value box =
-            builder.createBox(loc, boxTy, addr, getShapeFromDecl(src),
-                              /*slice=*/nullptr, lenParams,
-                              /*tdesc=*/nullptr);
-        mlir::Value memBox = builder.createTemporary(loc, box.getType());
-        builder.create<fir::StoreOp>(loc, box, memBox);
-        src = memBox;
-      }
+      if (!mlir::isa<fir::BaseBoxType>(srcTy))
+        src = emboxSrc(rewriter, op, symtab);
 
       auto fTy = func.getFunctionType();
       mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
@@ -582,16 +620,7 @@ struct CUFDataTransferOpConversion
       rewriter.eraseOp(op);
     } else {
       // Transfer from a descriptor.
-
-      mlir::Value addr = getDeviceAddress(rewriter, op.getDstMutable(), symtab);
-      mlir::Type boxTy = fir::BoxType::get(dstTy);
-      llvm::SmallVector<mlir::Value> lenParams;
-      mlir::Value box =
-          builder.createBox(loc, boxTy, addr, getShapeFromDecl(op.getDst()),
-                            /*slice=*/nullptr, lenParams,
-                            /*tdesc=*/nullptr);
-      mlir::Value memBox = builder.createTemporary(loc, box.getType());
-      builder.create<fir::StoreOp>(loc, box, memBox);
+      mlir::Value dst = emboxDst(rewriter, op, symtab);
 
       mlir::func::FuncOp func = fir::runtime::getRuntimeFunc<mkRTKey(
           CUFDataTransferDescDescNoRealloc)>(loc, builder);
@@ -601,7 +630,7 @@ struct CUFDataTransferOpConversion
       mlir::Value sourceLine =
           fir::factory::locationToLineNo(builder, loc, fTy.getInput(4));
       llvm::SmallVector<mlir::Value> args{
-          fir::runtime::createArguments(builder, loc, fTy, memBox, op.getSrc(),
+          fir::runtime::createArguments(builder, loc, fTy, dst, op.getSrc(),
                                         modeValue, sourceFile, sourceLine)};
       builder.create<fir::CallOp>(loc, func, args);
       rewriter.eraseOp(op);
diff --git a/flang/test/Fir/CUDA/cuda-data-transfer.fir b/flang/test/Fir/CUDA/cuda-data-transfer.fir
index d9588942b21e81..8497aee2e2cf9c 100644
--- a/flang/test/Fir/CUDA/cuda-data-transfer.fir
+++ b/flang/test/Fir/CUDA/cuda-data-transfer.fir
@@ -281,4 +281,18 @@ func.func @_QPdesc_global_ptr() {
 // CHECK: %[[AHOST_BOXNONE:.*]] = fir.convert %[[TEMP_BOX]] : (!fir.ref<!fir.box<!fir.array<10xi32>>>) -> !fir.ref<!fir.box<none>>
 // CHECK: fir.call @_FortranACUFDataTransferGlobalDescDesc(%[[ADEV_BOXNONE]], %[[AHOST_BOXNONE]], %c0{{.*}}, %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, !fir.ref<!fir.box<none>>, i32, !fir.ref<i8>, i32) -> none
 
+func.func @_QPscalar_to_array() {
+  %c1_i32 = arith.constant 1 : i32
+  %c10 = arith.constant 10 : index
+  %0 = cuf.alloc !fir.array<10xi32> {bindc_name = "a", data_attr = #cuf.cuda<device>, uniq_name = "_QFscalar_to_arrayEa"} -> !fir.ref<!fir.array<10xi32>>
+  %1 = fir.shape %c10 : (index) -> !fir.shape<1>
+  %2:2 = hlfir.declare %0(%1) {data_attr = #cuf.cuda<device>, uniq_name = "_QFscalar_to_arrayEa"} : (!fir.ref<!fir.array<10xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<10xi32>>, !fir.ref<!fir.array<10xi32>>)
+  cuf.data_transfer %c1_i32 to %2#0 {transfer_kind = #cuf.cuda_transfer<host_device>} : i32, !fir.ref<!fir.array<10xi32>>
+  cuf.free %2#1 : !fir.ref<!fir.array<10xi32>> {data_attr = #cuf.cuda<device>}
+  return
+}
+
+// CHECK-LABEL: func.func @_QPscalar_to_array()
+// CHECK: _FortranACUFDataTransferDescDescNoRealloc
+
 } // end of module

@clementval clementval merged commit ef8d88c into llvm:main Nov 7, 2024
11 checks passed
@clementval clementval deleted the cuf_scalar_to_array branch November 7, 2024 17:27
Groverkss pushed a commit to iree-org/llvm-project that referenced this pull request Nov 15, 2024
Do it via descriptor assignment until we have a more efficient way.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
flang:fir-hlfir flang Flang issues not falling into any other category
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants