Skip to content

Commit e8469f1

Browse files
authored
[flang][cuda] Add support for character type in cuf.alloc and cuf.data_transfer (#116277)
Add support for character type in bytes computation
1 parent 4b50ec4 commit e8469f1

File tree

3 files changed

+45
-15
lines changed

3 files changed

+45
-15
lines changed

flang/lib/Optimizer/Transforms/CUFOpConversion.cpp

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -268,24 +268,23 @@ static bool inDeviceContext(mlir::Operation *op) {
268268
static int computeWidth(mlir::Location loc, mlir::Type type,
269269
fir::KindMapping &kindMap) {
270270
auto eleTy = fir::unwrapSequenceType(type);
271-
int width = 0;
272-
if (auto t{mlir::dyn_cast<mlir::IntegerType>(eleTy)}) {
273-
width = t.getWidth() / 8;
274-
} else if (auto t{mlir::dyn_cast<mlir::FloatType>(eleTy)}) {
275-
width = t.getWidth() / 8;
276-
} else if (eleTy.isInteger(1)) {
277-
width = 1;
278-
} else if (auto t{mlir::dyn_cast<fir::LogicalType>(eleTy)}) {
279-
int kind = t.getFKind();
280-
width = kindMap.getLogicalBitsize(kind) / 8;
281-
} else if (auto t{mlir::dyn_cast<mlir::ComplexType>(eleTy)}) {
271+
if (auto t{mlir::dyn_cast<mlir::IntegerType>(eleTy)})
272+
return t.getWidth() / 8;
273+
if (auto t{mlir::dyn_cast<mlir::FloatType>(eleTy)})
274+
return t.getWidth() / 8;
275+
if (eleTy.isInteger(1))
276+
return 1;
277+
if (auto t{mlir::dyn_cast<fir::LogicalType>(eleTy)})
278+
return kindMap.getLogicalBitsize(t.getFKind()) / 8;
279+
if (auto t{mlir::dyn_cast<mlir::ComplexType>(eleTy)}) {
282280
int elemSize =
283281
mlir::cast<mlir::FloatType>(t.getElementType()).getWidth() / 8;
284-
width = 2 * elemSize;
285-
} else {
286-
mlir::emitError(loc, "unsupported type");
282+
return 2 * elemSize;
287283
}
288-
return width;
284+
if (auto t{mlir::dyn_cast_or_null<fir::CharacterType>(eleTy)})
285+
return kindMap.getCharacterBitsize(t.getFKind()) / 8;
286+
mlir::emitError(loc, "unsupported type");
287+
return 0;
289288
}
290289

291290
struct CUFAllocOpConversion : public mlir::OpRewritePattern<cuf::AllocOp> {

flang/test/Fir/CUDA/cuda-alloc-free.fir

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,4 +83,15 @@ gpu.module @cuda_device_mod [#nvvm.target] {
8383
// CHECK-LABEL: gpu.func @_QMalloc() kernel
8484
// CHECK: fir.alloca !fir.box<!fir.heap<!fir.array<?xf32>>> {bindc_name = "a", uniq_name = "_QMallocEa"}
8585

86+
func.func @_QQalloc_char() attributes {fir.bindc_name = "alloc_char"} {
87+
%c1 = arith.constant 1 : index
88+
%0 = cuf.alloc !fir.array<10x!fir.char<1>>(%c1 : index) {bindc_name = "a", data_attr = #cuf.cuda<device>, uniq_name = "_QFEa"} -> !fir.ref<!fir.array<10x!fir.char<1>>>
89+
return
90+
}
91+
92+
// CHECK-LABEL: func.func @_QQalloc_char()
93+
// CHECK: %[[BYTES:.*]] = arith.muli %c10{{.*}}, %c1{{.*}} : index
94+
// CHECK: %[[BYTES_CONV:.*]] = fir.convert %[[BYTES]] : (index) -> i64
95+
// CHECK: fir.call @_FortranACUFMemAlloc(%[[BYTES_CONV]], %c0{{.*}}, %{{.*}}, %{{.*}}) : (i64, i32, !fir.ref<i8>, i32) -> !fir.llvm_ptr<i8>
96+
8697
} // end module

flang/test/Fir/CUDA/cuda-data-transfer.fir

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -385,6 +385,25 @@ func.func @_QPdevice_addr_conv() {
385385
// CHECK: fir.embox %[[DEV_ADDR_CONV]](%{{.*}}) : (!fir.ref<!fir.array<4xf32>>, !fir.shape<1>) -> !fir.box<!fir.array<4xf32>>
386386
// CHECK: fir.call @_FortranACUFDataTransferDescDescNoRealloc
387387

388+
389+
func.func @_QQchar_transfer() attributes {fir.bindc_name = "char_transfer"} {
390+
%c1 = arith.constant 1 : index
391+
%c10 = arith.constant 10 : index
392+
%0 = cuf.alloc !fir.array<10x!fir.char<1>>(%c1 : index) {bindc_name = "a", data_attr = #cuf.cuda<device>, uniq_name = "_QFEa"} -> !fir.ref<!fir.array<10x!fir.char<1>>>
393+
%1 = fir.shape %c10 : (index) -> !fir.shape<1>
394+
%2 = fir.declare %0(%1) typeparams %c1 {data_attr = #cuf.cuda<device>, uniq_name = "_QFEa"} : (!fir.ref<!fir.array<10x!fir.char<1>>>, !fir.shape<1>, index) -> !fir.ref<!fir.array<10x!fir.char<1>>>
395+
%3 = fir.alloca !fir.array<10x!fir.char<1>> {bindc_name = "b", uniq_name = "_QFEb"}
396+
%4 = fir.declare %3(%1) typeparams %c1 {uniq_name = "_QFEb"} : (!fir.ref<!fir.array<10x!fir.char<1>>>, !fir.shape<1>, index) -> !fir.ref<!fir.array<10x!fir.char<1>>>
397+
cuf.data_transfer %4 to %2 {transfer_kind = #cuf.cuda_transfer<host_device>} : !fir.ref<!fir.array<10x!fir.char<1>>>, !fir.ref<!fir.array<10x!fir.char<1>>>
398+
cuf.free %2 : !fir.ref<!fir.array<10x!fir.char<1>>> {data_attr = #cuf.cuda<device>}
399+
return
400+
}
401+
402+
// CHECK-LABEL: func.func @_QQchar_transfer()
403+
// CHECK: fir.call @_FortranACUFMemAlloc
404+
// CHECK: %[[BYTES:.*]] = arith.muli %c10{{.*}}, %c1{{.*}} : i64
405+
// CHECK: fir.call @_FortranACUFDataTransferPtrPtr(%{{.*}}, %{{.*}}, %[[BYTES]], %c0{{.*}}, %{{.*}}, %{{.*}}) : (!fir.llvm_ptr<i8>, !fir.llvm_ptr<i8>, i64, i32, !fir.ref<i8>, i32) -> none
406+
388407
func.func @_QPdevmul(%arg0: !fir.ref<!fir.array<1x?xf32>> {fir.bindc_name = "b"}, %arg1: !fir.ref<i32> {fir.bindc_name = "wa"}, %arg2: !fir.ref<i32> {fir.bindc_name = "wb"}) {
389408
%c0_i64 = arith.constant 0 : i64
390409
%c1_i32 = arith.constant 1 : i32
@@ -424,4 +443,5 @@ func.func @_QPdevmul(%arg0: !fir.ref<!fir.array<1x?xf32>> {fir.bindc_name = "b"}
424443
// CHECK: %[[SRC:.*]] = fir.convert %[[ALLOCA]] : (!fir.ref<!fir.box<!fir.array<?x?xf32>>>) -> !fir.ref<!fir.box<none>>
425444
// CHECK: fir.call @_FortranACUFDataTransferDescDesc(%{{.*}}, %[[SRC]], %{{.*}}, %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, !fir.ref<!fir.box<none>>, i32, !fir.ref<i8>, i32) -> none
426445

446+
427447
} // end of module

0 commit comments

Comments
 (0)