Skip to content

Commit ef8d88c

Browse files
authored
[flang][cuda] Support scalar to array data transfer (#115273)
Do it via descriptor assignment until we have a more efficient way.
1 parent 7c63b10 commit ef8d88c

File tree

2 files changed

+81
-38
lines changed

2 files changed

+81
-38
lines changed

flang/lib/Optimizer/Transforms/CUFOpConversion.cpp

Lines changed: 67 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -448,6 +448,53 @@ static mlir::Value getShapeFromDecl(mlir::Value src) {
448448
return mlir::Value{};
449449
}
450450

451+
/// Build a descriptor (fir.box) for the source operand of a cuf.data_transfer
/// and return the address of a temporary that holds this descriptor.
///
/// \param rewriter pattern rewriter used to create the new operations.
/// \param op       data transfer operation whose source is being boxed.
/// \param symtab   symbol table forwarded to getDeviceAddress.
/// \return the temporary (from createTemporary) the box has been stored into.
static mlir::Value emboxSrc(mlir::PatternRewriter &rewriter,
                            cuf::DataTransferOp op,
                            const mlir::SymbolTable &symtab) {
  auto mod = op->getParentOfType<mlir::ModuleOp>();
  mlir::Location loc = op.getLoc();
  fir::FirOpBuilder builder(rewriter, mod);
  mlir::Value addr;
  mlir::Type srcTy = fir::unwrapRefType(op.getSrc().getType());
  // Match on the Value instead of on getDefiningOp(): the source may be a
  // block argument, in which case getDefiningOp() is null and must not be
  // handed to the Operation* overload of matchPattern. The Value overload
  // simply reports "no match" when there is no defining operation.
  if (fir::isa_trivial(srcTy) &&
      mlir::matchPattern(op.getSrc(), mlir::m_Constant())) {
    // Put constant in memory if it is not.
    mlir::Value alloc = builder.createTemporary(loc, srcTy);
    builder.create<fir::StoreOp>(loc, op.getSrc(), alloc);
    addr = alloc;
  } else {
    addr = getDeviceAddress(rewriter, op.getSrcMutable(), symtab);
  }
  // No length parameters, slice or type descriptor are supplied; the shape
  // (possibly a null value) is whatever getShapeFromDecl finds for the source.
  llvm::SmallVector<mlir::Value> lenParams;
  mlir::Type boxTy = fir::BoxType::get(srcTy);
  mlir::Value box =
      builder.createBox(loc, boxTy, addr, getShapeFromDecl(op.getSrc()),
                        /*slice=*/nullptr, lenParams,
                        /*tdesc=*/nullptr);
  // Callers pass the descriptor by reference, so spill it to a temporary.
  mlir::Value src = builder.createTemporary(loc, box.getType());
  builder.create<fir::StoreOp>(loc, box, src);
  return src;
}
478+
479+
/// Wrap the destination operand of a cuf.data_transfer in a descriptor
/// (fir.box) and hand back the address of a temporary holding it.
///
/// \param rewriter pattern rewriter used to create the new operations.
/// \param op       data transfer operation whose destination is being boxed.
/// \param symtab   symbol table forwarded to getDeviceAddress.
/// \return the temporary (from createTemporary) the box has been stored into.
static mlir::Value emboxDst(mlir::PatternRewriter &rewriter,
                            cuf::DataTransferOp op,
                            const mlir::SymbolTable &symtab) {
  const mlir::Location loc = op.getLoc();
  fir::FirOpBuilder builder(rewriter, op->getParentOfType<mlir::ModuleOp>());

  // The boxed type is derived from the destination operand as it is on entry,
  // i.e. before getDeviceAddress is given the mutable operand.
  mlir::Type elementTy = fir::unwrapRefType(op.getDst().getType());
  mlir::Value deviceAddr =
      getDeviceAddress(rewriter, op.getDstMutable(), symtab);
  mlir::Value shape = getShapeFromDecl(op.getDst());

  // The descriptor carries no length parameters, slice or type descriptor.
  llvm::SmallVector<mlir::Value> noLenParams;
  mlir::Value descriptor =
      builder.createBox(loc, fir::BoxType::get(elementTy), deviceAddr, shape,
                        /*slice=*/nullptr, noLenParams,
                        /*tdesc=*/nullptr);

  // Spill the descriptor to memory so that it can be passed by reference.
  mlir::Value boxStorage = builder.createTemporary(loc, descriptor.getType());
  builder.create<fir::StoreOp>(loc, descriptor, boxStorage);
  return boxStorage;
}
497+
451498
struct CUFDataTransferOpConversion
452499
: public mlir::OpRewritePattern<cuf::DataTransferOp> {
453500
using OpRewritePattern::OpRewritePattern;
@@ -486,10 +533,22 @@ struct CUFDataTransferOpConversion
486533
!mlir::isa<fir::BaseBoxType>(dstTy)) {
487534

488535
if (fir::isa_trivial(srcTy) && !fir::isa_trivial(dstTy)) {
489-
// TODO: scalar to array data transfer.
490-
mlir::emitError(loc,
491-
"not yet implemented: scalar to array data transfer\n");
492-
return mlir::failure();
536+
// Initialization of an array from a scalar value should be implemented
537+
// via a kernel launch. Use the flang runtime via the Assign function
538+
// until we have more infrastructure.
539+
mlir::Value src = emboxSrc(rewriter, op, symtab);
540+
mlir::Value dst = emboxDst(rewriter, op, symtab);
541+
mlir::func::FuncOp func = fir::runtime::getRuntimeFunc<mkRTKey(
542+
CUFDataTransferDescDescNoRealloc)>(loc, builder);
543+
auto fTy = func.getFunctionType();
544+
mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
545+
mlir::Value sourceLine =
546+
fir::factory::locationToLineNo(builder, loc, fTy.getInput(4));
547+
llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
548+
builder, loc, fTy, dst, src, modeValue, sourceFile, sourceLine)};
549+
builder.create<fir::CallOp>(loc, func, args);
550+
rewriter.eraseOp(op);
551+
return mlir::success();
493552
}
494553

495554
mlir::Type i64Ty = builder.getI64Type();
@@ -548,29 +607,8 @@ struct CUFDataTransferOpConversion
548607
mlir::Value dst = op.getDst();
549608
mlir::Value src = op.getSrc();
550609

551-
if (!mlir::isa<fir::BaseBoxType>(srcTy)) {
552-
// If src is not a descriptor, create one.
553-
mlir::Value addr;
554-
if (fir::isa_trivial(srcTy) &&
555-
mlir::matchPattern(op.getSrc().getDefiningOp(),
556-
mlir::m_Constant())) {
557-
// Put constant in memory if it is not.
558-
mlir::Value alloc = builder.createTemporary(loc, srcTy);
559-
builder.create<fir::StoreOp>(loc, op.getSrc(), alloc);
560-
addr = alloc;
561-
} else {
562-
addr = getDeviceAddress(rewriter, op.getSrcMutable(), symtab);
563-
}
564-
mlir::Type boxTy = fir::BoxType::get(srcTy);
565-
llvm::SmallVector<mlir::Value> lenParams;
566-
mlir::Value box =
567-
builder.createBox(loc, boxTy, addr, getShapeFromDecl(src),
568-
/*slice=*/nullptr, lenParams,
569-
/*tdesc=*/nullptr);
570-
mlir::Value memBox = builder.createTemporary(loc, box.getType());
571-
builder.create<fir::StoreOp>(loc, box, memBox);
572-
src = memBox;
573-
}
610+
if (!mlir::isa<fir::BaseBoxType>(srcTy))
611+
src = emboxSrc(rewriter, op, symtab);
574612

575613
auto fTy = func.getFunctionType();
576614
mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
@@ -582,16 +620,7 @@ struct CUFDataTransferOpConversion
582620
rewriter.eraseOp(op);
583621
} else {
584622
// Transfer from a descriptor.
585-
586-
mlir::Value addr = getDeviceAddress(rewriter, op.getDstMutable(), symtab);
587-
mlir::Type boxTy = fir::BoxType::get(dstTy);
588-
llvm::SmallVector<mlir::Value> lenParams;
589-
mlir::Value box =
590-
builder.createBox(loc, boxTy, addr, getShapeFromDecl(op.getDst()),
591-
/*slice=*/nullptr, lenParams,
592-
/*tdesc=*/nullptr);
593-
mlir::Value memBox = builder.createTemporary(loc, box.getType());
594-
builder.create<fir::StoreOp>(loc, box, memBox);
623+
mlir::Value dst = emboxDst(rewriter, op, symtab);
595624

596625
mlir::func::FuncOp func = fir::runtime::getRuntimeFunc<mkRTKey(
597626
CUFDataTransferDescDescNoRealloc)>(loc, builder);
@@ -601,7 +630,7 @@ struct CUFDataTransferOpConversion
601630
mlir::Value sourceLine =
602631
fir::factory::locationToLineNo(builder, loc, fTy.getInput(4));
603632
llvm::SmallVector<mlir::Value> args{
604-
fir::runtime::createArguments(builder, loc, fTy, memBox, op.getSrc(),
633+
fir::runtime::createArguments(builder, loc, fTy, dst, op.getSrc(),
605634
modeValue, sourceFile, sourceLine)};
606635
builder.create<fir::CallOp>(loc, func, args);
607636
rewriter.eraseOp(op);

flang/test/Fir/CUDA/cuda-data-transfer.fir

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -281,4 +281,18 @@ func.func @_QPdesc_global_ptr() {
281281
// CHECK: %[[AHOST_BOXNONE:.*]] = fir.convert %[[TEMP_BOX]] : (!fir.ref<!fir.box<!fir.array<10xi32>>>) -> !fir.ref<!fir.box<none>>
282282
// CHECK: fir.call @_FortranACUFDataTransferGlobalDescDesc(%[[ADEV_BOXNONE]], %[[AHOST_BOXNONE]], %c0{{.*}}, %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, !fir.ref<!fir.box<none>>, i32, !fir.ref<i8>, i32) -> none
283283

284+
// Scalar-to-array data transfer: the constant i32 %c1_i32 is transferred to
// the 10-element device array "a" with transfer_kind host_device. The CHECK
// lines following this function expect a reference to
// _FortranACUFDataTransferDescDescNoRealloc in the converted output.
func.func @_QPscalar_to_array() {
285+
%c1_i32 = arith.constant 1 : i32
286+
%c10 = arith.constant 10 : index
287+
%0 = cuf.alloc !fir.array<10xi32> {bindc_name = "a", data_attr = #cuf.cuda<device>, uniq_name = "_QFscalar_to_arrayEa"} -> !fir.ref<!fir.array<10xi32>>
288+
%1 = fir.shape %c10 : (index) -> !fir.shape<1>
289+
%2:2 = hlfir.declare %0(%1) {data_attr = #cuf.cuda<device>, uniq_name = "_QFscalar_to_arrayEa"} : (!fir.ref<!fir.array<10xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<10xi32>>, !fir.ref<!fir.array<10xi32>>)
290+
cuf.data_transfer %c1_i32 to %2#0 {transfer_kind = #cuf.cuda_transfer<host_device>} : i32, !fir.ref<!fir.array<10xi32>>
291+
cuf.free %2#1 : !fir.ref<!fir.array<10xi32>> {data_attr = #cuf.cuda<device>}
292+
return
293+
}
294+
295+
// CHECK-LABEL: func.func @_QPscalar_to_array()
296+
// CHECK: _FortranACUFDataTransferDescDescNoRealloc
297+
284298
} // end of module

0 commit comments

Comments
 (0)