Skip to content

Commit 93b2e47

Browse files
authored
[flang][cuda] Avoid assign element mismatch when doing data transfer from a constant (#128252)
Currently when we do a CUDA data transfer from a constant, we embox it and delegate the assignment to the runtime. When the type of the constant is not exactly the same as the destination descriptor, the runtime will emit an assignment mismatch error. Convert the constant when necessary so the assignment is fine.
1 parent 5f8da7e commit 93b2e47

File tree

2 files changed

+41
-7
lines changed

2 files changed

+41
-7
lines changed

flang/lib/Optimizer/Transforms/CUFOpConversion.cpp

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -541,7 +541,8 @@ static mlir::Value getShapeFromDecl(mlir::Value src) {
541541

542542
static mlir::Value emboxSrc(mlir::PatternRewriter &rewriter,
543543
cuf::DataTransferOp op,
544-
const mlir::SymbolTable &symtab) {
544+
const mlir::SymbolTable &symtab,
545+
mlir::Type dstEleTy = nullptr) {
545546
auto mod = op->getParentOfType<mlir::ModuleOp>();
546547
mlir::Location loc = op.getLoc();
547548
fir::FirOpBuilder builder(rewriter, mod);
@@ -555,11 +556,21 @@ static mlir::Value emboxSrc(mlir::PatternRewriter &rewriter,
555556
// from a LOGICAL constant. Store it as a fir.logical.
556557
srcTy = fir::LogicalType::get(rewriter.getContext(), 4);
557558
src = createConvertOp(rewriter, loc, srcTy, src);
559+
addr = builder.createTemporary(loc, srcTy);
560+
builder.create<fir::StoreOp>(loc, src, addr);
561+
} else {
562+
if (dstEleTy && fir::isa_trivial(dstEleTy) && srcTy != dstEleTy) {
563+
// Use dstEleTy and convert to avoid assign mismatch.
564+
addr = builder.createTemporary(loc, dstEleTy);
565+
auto conv = builder.create<fir::ConvertOp>(loc, dstEleTy, src);
566+
builder.create<fir::StoreOp>(loc, conv, addr);
567+
srcTy = dstEleTy;
568+
} else {
569+
// Put constant in memory if it is not.
570+
addr = builder.createTemporary(loc, srcTy);
571+
builder.create<fir::StoreOp>(loc, src, addr);
572+
}
558573
}
559-
// Put constant in memory if it is not.
560-
mlir::Value alloc = builder.createTemporary(loc, srcTy);
561-
builder.create<fir::StoreOp>(loc, src, alloc);
562-
addr = alloc;
563574
} else {
564575
addr = op.getSrc();
565576
}
@@ -729,7 +740,7 @@ struct CUFDataTransferOpConversion
729740
};
730741

731742
// Conversion of data transfer involving at least one descriptor.
732-
if (mlir::isa<fir::BaseBoxType>(dstTy)) {
743+
if (auto dstBoxTy = mlir::dyn_cast<fir::BaseBoxType>(dstTy)) {
733744
// Transfer to a descriptor.
734745
mlir::func::FuncOp func =
735746
isDstGlobal(op)
@@ -740,7 +751,8 @@ struct CUFDataTransferOpConversion
740751
mlir::Value dst = op.getDst();
741752
mlir::Value src = op.getSrc();
742753
if (!mlir::isa<fir::BaseBoxType>(srcTy)) {
743-
src = emboxSrc(rewriter, op, symtab);
754+
mlir::Type dstEleTy = fir::unwrapInnerType(dstBoxTy.getEleTy());
755+
src = emboxSrc(rewriter, op, symtab, dstEleTy);
744756
if (fir::isa_trivial(srcTy))
745757
func = fir::runtime::getRuntimeFunc<mkRTKey(CUFDataTransferCstDesc)>(
746758
loc, builder);

flang/test/Fir/CUDA/cuda-data-transfer.fir

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -582,4 +582,26 @@ func.func @_QPchecksums(%arg0: !fir.box<!fir.array<?xf64>> {cuf.data_attr = #cuf
582582
// CHECK: %[[SRC:.*]] = fir.convert %{{.*}} : (!fir.ref<!fir.box<!fir.array<?xf64>>>) -> !fir.ref<!fir.box<none>>
583583
// CHECK: fir.call @_FortranACUFDataTransferDescDescNoRealloc(%[[DST]], %[[SRC]], %{{.*}}, %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, !fir.ref<!fir.box<none>>, i32, !fir.ref<i8>, i32) -> ()
584584

585+
func.func @_QPsub20() {
586+
%0 = cuf.alloc !fir.box<!fir.heap<f32>> {bindc_name = "r", data_attr = #cuf.cuda<device>, uniq_name = "_QFsub20Er"} -> !fir.ref<!fir.box<!fir.heap<f32>>>
587+
%1 = fir.zero_bits !fir.heap<f32>
588+
%2 = fir.embox %1 {allocator_idx = 2 : i32} : (!fir.heap<f32>) -> !fir.box<!fir.heap<f32>>
589+
fir.store %2 to %0 : !fir.ref<!fir.box<!fir.heap<f32>>>
590+
%3:2 = hlfir.declare %0 {data_attr = #cuf.cuda<device>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFsub20Er"} : (!fir.ref<!fir.box<!fir.heap<f32>>>) -> (!fir.ref<!fir.box<!fir.heap<f32>>>, !fir.ref<!fir.box<!fir.heap<f32>>>)
591+
%c0_i32 = arith.constant 0 : i32
592+
cuf.data_transfer %c0_i32 to %3#0 {transfer_kind = #cuf.cuda_transfer<host_device>} : i32, !fir.ref<!fir.box<!fir.heap<f32>>>
593+
return
594+
}
595+
596+
// CHECK-LABEL:func.func @_QPsub20
597+
// CHECK: %[[BOX_ALLOCA:.*]] = fir.alloca !fir.box<f32>
598+
// CHECK: %[[TMP:.*]] = fir.alloca f32
599+
// CHECK: %[[CONV:.*]] = fir.convert %c0{{.*}} : (i32) -> f32
600+
// CHECK: fir.store %[[CONV]] to %[[TMP]] : !fir.ref<f32>
601+
// CHECK: %[[BOX:.*]] = fir.embox %[[TMP]] : (!fir.ref<f32>) -> !fir.box<f32>
602+
// CHECK: fir.store %[[BOX]] to %[[BOX_ALLOCA]] : !fir.ref<!fir.box<f32>>
603+
// CHECK: %[[BOX_NONE:.*]] = fir.convert %[[BOX_ALLOCA]] : (!fir.ref<!fir.box<f32>>) -> !fir.ref<!fir.box<none>>
604+
// CHECK: fir.call @_FortranACUFDataTransferCstDesc(%13, %[[BOX_NONE]], %c0{{.*}}, %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, !fir.ref<!fir.box<none>>, i32, !fir.ref<i8>, i32) -> ()
605+
585606
} // end of module
607+

0 commit comments

Comments
 (0)