Skip to content

[flang][cuda] Support scalar to array data transfer #115273

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Nov 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 67 additions & 38 deletions flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -448,6 +448,53 @@ static mlir::Value getShapeFromDecl(mlir::Value src) {
return mlir::Value{};
}

static mlir::Value emboxSrc(mlir::PatternRewriter &rewriter,
cuf::DataTransferOp op,
const mlir::SymbolTable &symtab) {
auto mod = op->getParentOfType<mlir::ModuleOp>();
mlir::Location loc = op.getLoc();
fir::FirOpBuilder builder(rewriter, mod);
mlir::Value addr;
mlir::Type srcTy = fir::unwrapRefType(op.getSrc().getType());
if (fir::isa_trivial(srcTy) &&
mlir::matchPattern(op.getSrc().getDefiningOp(), mlir::m_Constant())) {
// Put constant in memory if it is not.
mlir::Value alloc = builder.createTemporary(loc, srcTy);
builder.create<fir::StoreOp>(loc, op.getSrc(), alloc);
addr = alloc;
} else {
addr = getDeviceAddress(rewriter, op.getSrcMutable(), symtab);
}
llvm::SmallVector<mlir::Value> lenParams;
mlir::Type boxTy = fir::BoxType::get(srcTy);
mlir::Value box =
builder.createBox(loc, boxTy, addr, getShapeFromDecl(op.getSrc()),
/*slice=*/nullptr, lenParams,
/*tdesc=*/nullptr);
mlir::Value src = builder.createTemporary(loc, box.getType());
builder.create<fir::StoreOp>(loc, box, src);
return src;
}

static mlir::Value emboxDst(mlir::PatternRewriter &rewriter,
cuf::DataTransferOp op,
const mlir::SymbolTable &symtab) {
auto mod = op->getParentOfType<mlir::ModuleOp>();
mlir::Location loc = op.getLoc();
fir::FirOpBuilder builder(rewriter, mod);
mlir::Type dstTy = fir::unwrapRefType(op.getDst().getType());
mlir::Value dstAddr = getDeviceAddress(rewriter, op.getDstMutable(), symtab);
mlir::Type dstBoxTy = fir::BoxType::get(dstTy);
llvm::SmallVector<mlir::Value> lenParams;
mlir::Value dstBox =
builder.createBox(loc, dstBoxTy, dstAddr, getShapeFromDecl(op.getDst()),
/*slice=*/nullptr, lenParams,
/*tdesc=*/nullptr);
mlir::Value dst = builder.createTemporary(loc, dstBox.getType());
builder.create<fir::StoreOp>(loc, dstBox, dst);
return dst;
}

struct CUFDataTransferOpConversion
: public mlir::OpRewritePattern<cuf::DataTransferOp> {
using OpRewritePattern::OpRewritePattern;
Expand Down Expand Up @@ -486,10 +533,22 @@ struct CUFDataTransferOpConversion
!mlir::isa<fir::BaseBoxType>(dstTy)) {

if (fir::isa_trivial(srcTy) && !fir::isa_trivial(dstTy)) {
// TODO: scalar to array data transfer.
mlir::emitError(loc,
"not yet implemented: scalar to array data transfer\n");
return mlir::failure();
// Initialization of an array from a scalar value should be implemented
// via a kernel launch. Use the flan runtime via the Assign function
// until we have more infrastructure.
mlir::Value src = emboxSrc(rewriter, op, symtab);
mlir::Value dst = emboxDst(rewriter, op, symtab);
mlir::func::FuncOp func = fir::runtime::getRuntimeFunc<mkRTKey(
CUFDataTransferDescDescNoRealloc)>(loc, builder);
auto fTy = func.getFunctionType();
mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
mlir::Value sourceLine =
fir::factory::locationToLineNo(builder, loc, fTy.getInput(4));
llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
builder, loc, fTy, dst, src, modeValue, sourceFile, sourceLine)};
builder.create<fir::CallOp>(loc, func, args);
rewriter.eraseOp(op);
return mlir::success();
}

mlir::Type i64Ty = builder.getI64Type();
Expand Down Expand Up @@ -548,29 +607,8 @@ struct CUFDataTransferOpConversion
mlir::Value dst = op.getDst();
mlir::Value src = op.getSrc();

if (!mlir::isa<fir::BaseBoxType>(srcTy)) {
// If src is not a descriptor, create one.
mlir::Value addr;
if (fir::isa_trivial(srcTy) &&
mlir::matchPattern(op.getSrc().getDefiningOp(),
mlir::m_Constant())) {
// Put constant in memory if it is not.
mlir::Value alloc = builder.createTemporary(loc, srcTy);
builder.create<fir::StoreOp>(loc, op.getSrc(), alloc);
addr = alloc;
} else {
addr = getDeviceAddress(rewriter, op.getSrcMutable(), symtab);
}
mlir::Type boxTy = fir::BoxType::get(srcTy);
llvm::SmallVector<mlir::Value> lenParams;
mlir::Value box =
builder.createBox(loc, boxTy, addr, getShapeFromDecl(src),
/*slice=*/nullptr, lenParams,
/*tdesc=*/nullptr);
mlir::Value memBox = builder.createTemporary(loc, box.getType());
builder.create<fir::StoreOp>(loc, box, memBox);
src = memBox;
}
if (!mlir::isa<fir::BaseBoxType>(srcTy))
src = emboxSrc(rewriter, op, symtab);

auto fTy = func.getFunctionType();
mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
Expand All @@ -582,16 +620,7 @@ struct CUFDataTransferOpConversion
rewriter.eraseOp(op);
} else {
// Transfer from a descriptor.

mlir::Value addr = getDeviceAddress(rewriter, op.getDstMutable(), symtab);
mlir::Type boxTy = fir::BoxType::get(dstTy);
llvm::SmallVector<mlir::Value> lenParams;
mlir::Value box =
builder.createBox(loc, boxTy, addr, getShapeFromDecl(op.getDst()),
/*slice=*/nullptr, lenParams,
/*tdesc=*/nullptr);
mlir::Value memBox = builder.createTemporary(loc, box.getType());
builder.create<fir::StoreOp>(loc, box, memBox);
mlir::Value dst = emboxDst(rewriter, op, symtab);

mlir::func::FuncOp func = fir::runtime::getRuntimeFunc<mkRTKey(
CUFDataTransferDescDescNoRealloc)>(loc, builder);
Expand All @@ -601,7 +630,7 @@ struct CUFDataTransferOpConversion
mlir::Value sourceLine =
fir::factory::locationToLineNo(builder, loc, fTy.getInput(4));
llvm::SmallVector<mlir::Value> args{
fir::runtime::createArguments(builder, loc, fTy, memBox, op.getSrc(),
fir::runtime::createArguments(builder, loc, fTy, dst, op.getSrc(),
modeValue, sourceFile, sourceLine)};
builder.create<fir::CallOp>(loc, func, args);
rewriter.eraseOp(op);
Expand Down
14 changes: 14 additions & 0 deletions flang/test/Fir/CUDA/cuda-data-transfer.fir
Original file line number Diff line number Diff line change
Expand Up @@ -281,4 +281,18 @@ func.func @_QPdesc_global_ptr() {
// CHECK: %[[AHOST_BOXNONE:.*]] = fir.convert %[[TEMP_BOX]] : (!fir.ref<!fir.box<!fir.array<10xi32>>>) -> !fir.ref<!fir.box<none>>
// CHECK: fir.call @_FortranACUFDataTransferGlobalDescDesc(%[[ADEV_BOXNONE]], %[[AHOST_BOXNONE]], %c0{{.*}}, %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, !fir.ref<!fir.box<none>>, i32, !fir.ref<i8>, i32) -> none

func.func @_QPscalar_to_array() {
%c1_i32 = arith.constant 1 : i32
%c10 = arith.constant 10 : index
%0 = cuf.alloc !fir.array<10xi32> {bindc_name = "a", data_attr = #cuf.cuda<device>, uniq_name = "_QFscalar_to_arrayEa"} -> !fir.ref<!fir.array<10xi32>>
%1 = fir.shape %c10 : (index) -> !fir.shape<1>
%2:2 = hlfir.declare %0(%1) {data_attr = #cuf.cuda<device>, uniq_name = "_QFscalar_to_arrayEa"} : (!fir.ref<!fir.array<10xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<10xi32>>, !fir.ref<!fir.array<10xi32>>)
cuf.data_transfer %c1_i32 to %2#0 {transfer_kind = #cuf.cuda_transfer<host_device>} : i32, !fir.ref<!fir.array<10xi32>>
cuf.free %2#1 : !fir.ref<!fir.array<10xi32>> {data_attr = #cuf.cuda<device>}
return
}

// CHECK-LABEL: func.func @_QPscalar_to_array()
// CHECK: _FortranACUFDataTransferDescDescNoRealloc

} // end of module
Loading