Skip to content

Commit d4eb430

Browse files
authored
[flang][cuda] Support derived type in cuf.alloc (#115550)
The number of bytes to allocate was not computed when using `cuf.alloc` with a derived type. Update the conversion to compute the number of bytes for derived types and to emit an error when the type is not supported.
1 parent fef4c8a commit d4eb430

File tree

2 files changed

+19
-0
lines changed

2 files changed

+19
-0
lines changed

flang/lib/Optimizer/Transforms/CUFOpConversion.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -337,6 +337,13 @@ struct CUFAllocOpConversion : public mlir::OpRewritePattern<cuf::AllocOp> {
337337
seqTy.getConstantArraySize());
338338
}
339339
bytes = rewriter.create<mlir::arith::MulIOp>(loc, nbElem, width);
340+
} else if (fir::isa_derived(op.getInType())) {
341+
mlir::Type structTy = typeConverter->convertType(op.getInType());
342+
std::size_t structSize = dl->getTypeSizeInBits(structTy) / 8;
343+
bytes = builder.createIntegerConstant(loc, builder.getIndexType(),
344+
structSize);
345+
} else {
346+
mlir::emitError(loc, "unsupported type in cuf.alloc\n");
340347
}
341348
mlir::func::FuncOp func =
342349
fir::runtime::getRuntimeFunc<mkRTKey(CUFMemAlloc)>(loc, builder);

flang/test/Fir/CUDA/cuda-alloc-free.fir

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,4 +61,16 @@ func.func @_QPsub3(%arg0: !fir.ref<i32> {fir.bindc_name = "n"}, %arg1: !fir.ref<
6161
// CHECK: %{{.*}} = fir.call @_FortranACUFMemAlloc(%[[CONV_BYTES]], %c0{{.*}}, %{{.*}}, %{{.*}}) : (i64, i32, !fir.ref<i8>, i32) -> !fir.llvm_ptr<i8>
6262
// CHECK: fir.call @_FortranACUFMemFree
6363

64+
func.func @_QPtest_type() {
65+
%0 = cuf.alloc !fir.type<_QMbarTcmplx{id:i32,c:complex<f32>}> {bindc_name = "a", data_attr = #cuf.cuda<device>, uniq_name = "_QFtest_typeEa"} -> !fir.ref<!fir.type<_QMbarTcmplx{id:i32,c:complex<f32>}>>
66+
%1 = fir.declare %0 {data_attr = #cuf.cuda<device>, uniq_name = "_QFtest_typeEa"} : (!fir.ref<!fir.type<_QMbarTcmplx{id:i32,c:complex<f32>}>>) -> !fir.ref<!fir.type<_QMbarTcmplx{id:i32,c:complex<f32>}>>
67+
cuf.free %1 : !fir.ref<!fir.type<_QMbarTcmplx{id:i32,c:complex<f32>}>> {data_attr = #cuf.cuda<device>}
68+
return
69+
}
70+
71+
// CHECK-LABEL: func.func @_QPtest_type()
72+
// CHECK: %[[BYTES:.*]] = arith.constant 12 : index
73+
// CHECK: %[[CONV_BYTES:.*]] = fir.convert %[[BYTES]] : (index) -> i64
74+
// CHECK: fir.call @_FortranACUFMemAlloc(%[[CONV_BYTES]], %c0{{.*}}, %{{.*}}, %{{.*}}) : (i64, i32, !fir.ref<i8>, i32) -> !fir.llvm_ptr<i8>
75+
6476
} // end module

0 commit comments

Comments
 (0)