Skip to content

Commit 39e254e

Browse files
authored
[flang][cuda] Convert cuf.alloc and cuf.free for scalar and arrays (#110055)
This patch adds conversion of cuf.alloc and cuf.free for scalars, constant-size arrays, and dynamic-size arrays.
1 parent 5d19d55 commit 39e254e

File tree

4 files changed

+169
-60
lines changed

4 files changed

+169
-60
lines changed

flang/include/flang/Runtime/CUDA/common.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#ifndef FORTRAN_RUNTIME_CUDA_COMMON_H_
1010
#define FORTRAN_RUNTIME_CUDA_COMMON_H_
1111

12+
#include "flang/Optimizer/Dialect/CUF/Attributes/CUFAttr.h"
1213
#include "flang/Runtime/descriptor.h"
1314
#include "flang/Runtime/entry-names.h"
1415

@@ -34,4 +35,16 @@ static constexpr unsigned kDeviceToDevice = 2;
3435
terminator.Crash("'%s' failed with '%s'", #expr, name); \
3536
}(expr)
3637

38+
/// Translate a CUF data attribute into the corresponding memory-type code.
///
/// Device, Managed, Unified and Pinned map to kMemTypeDevice, kMemTypeManaged,
/// kMemTypeUnified and kMemTypePinned respectively. Every other attribute is
/// rejected through llvm::report_fatal_error.
static inline unsigned getMemType(cuf::DataAttribute attr) {
  switch (attr) {
  case cuf::DataAttribute::Device:
    return kMemTypeDevice;
  case cuf::DataAttribute::Managed:
    return kMemTypeManaged;
  case cuf::DataAttribute::Unified:
    return kMemTypeUnified;
  case cuf::DataAttribute::Pinned:
    return kMemTypePinned;
  default:
    // Remaining attributes have no memory-type code here.
    break;
  }
  llvm::report_fatal_error("unsupported memory type");
}
49+
3750
#endif // FORTRAN_RUNTIME_CUDA_COMMON_H_

flang/lib/Optimizer/Transforms/CufOpConversion.cpp

Lines changed: 92 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,29 @@ static bool inDeviceContext(mlir::Operation *op) {
183183
return false;
184184
}
185185

186+
/// Compute the size in bytes of a single element of \p type.
///
/// \p type may be a scalar or a sequence type; sequence types are unwrapped to
/// their element type first. Supported element types are builtin integers
/// (including i1), builtin floats, fir.logical and fir.complex. Any other
/// element type is reported through llvm::report_fatal_error.
///
/// \param loc      Source location (currently unused by the computation).
/// \param type     Scalar or sequence type to measure.
/// \param kindMap  Kind mapping used to resolve the bit size of fir.logical
///                 and fir.complex kinds.
/// \return The element width in bytes.
static int computeWidth(mlir::Location loc, mlir::Type type,
                        fir::KindMapping &kindMap) {
  auto eleTy = fir::unwrapSequenceType(type);
  // i1 has to be handled before the generic integer case: it is an
  // mlir::IntegerType as well, and 1 / 8 would otherwise yield a width of 0.
  if (eleTy.isInteger(1))
    return 1;
  if (auto t{mlir::dyn_cast<mlir::IntegerType>(eleTy)})
    return t.getWidth() / 8;
  // NOTE(review): bit width / 8 may be smaller than the in-memory storage
  // size for some float types (e.g. f80) -- confirm against the data layout.
  if (auto t{mlir::dyn_cast<mlir::FloatType>(eleTy)})
    return t.getWidth() / 8;
  if (auto t{mlir::dyn_cast<fir::LogicalType>(eleTy)})
    return kindMap.getLogicalBitsize(t.getFKind()) / 8;
  if (auto t{mlir::dyn_cast<fir::ComplexType>(eleTy)}) {
    // A complex value is two reals of the same kind.
    int elemSize = kindMap.getRealBitsize(t.getFKind()) / 8;
    return 2 * elemSize;
  }
  llvm::report_fatal_error("unsupported type");
}
208+
186209
struct CufAllocOpConversion : public mlir::OpRewritePattern<cuf::AllocOp> {
187210
using OpRewritePattern::OpRewritePattern;
188211

@@ -193,11 +216,6 @@ struct CufAllocOpConversion : public mlir::OpRewritePattern<cuf::AllocOp> {
193216
mlir::LogicalResult
194217
matchAndRewrite(cuf::AllocOp op,
195218
mlir::PatternRewriter &rewriter) const override {
196-
auto boxTy = mlir::dyn_cast_or_null<fir::BaseBoxType>(op.getInType());
197-
198-
// Only convert cuf.alloc that allocates a descriptor.
199-
if (!boxTy)
200-
return failure();
201219

202220
if (inDeviceContext(op.getOperation())) {
203221
// In device context just replace the cuf.alloc operation with a fir.alloc
@@ -212,11 +230,56 @@ struct CufAllocOpConversion : public mlir::OpRewritePattern<cuf::AllocOp> {
212230
auto mod = op->getParentOfType<mlir::ModuleOp>();
213231
fir::FirOpBuilder builder(rewriter, mod);
214232
mlir::Location loc = op.getLoc();
233+
mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
234+
235+
if (!mlir::dyn_cast_or_null<fir::BaseBoxType>(op.getInType())) {
236+
// Convert scalar and known size array allocations.
237+
mlir::Value bytes;
238+
fir::KindMapping kindMap{fir::getKindMapping(mod)};
239+
if (fir::isa_trivial(op.getInType())) {
240+
int width = computeWidth(loc, op.getInType(), kindMap);
241+
bytes =
242+
builder.createIntegerConstant(loc, builder.getIndexType(), width);
243+
} else if (auto seqTy = mlir::dyn_cast_or_null<fir::SequenceType>(
244+
op.getInType())) {
245+
mlir::Value width = builder.createIntegerConstant(
246+
loc, builder.getIndexType(),
247+
computeWidth(loc, seqTy.getEleTy(), kindMap));
248+
mlir::Value nbElem;
249+
if (fir::sequenceWithNonConstantShape(seqTy)) {
250+
assert(!op.getShape().empty() && "expect shape with dynamic arrays");
251+
nbElem = builder.loadIfRef(loc, op.getShape()[0]);
252+
for (unsigned i = 1; i < op.getShape().size(); ++i) {
253+
nbElem = rewriter.create<mlir::arith::MulIOp>(
254+
loc, nbElem, builder.loadIfRef(loc, op.getShape()[i]));
255+
}
256+
} else {
257+
nbElem = builder.createIntegerConstant(loc, builder.getIndexType(),
258+
seqTy.getConstantArraySize());
259+
}
260+
bytes = rewriter.create<mlir::arith::MulIOp>(loc, nbElem, width);
261+
}
262+
mlir::func::FuncOp func =
263+
fir::runtime::getRuntimeFunc<mkRTKey(CUFMemAlloc)>(loc, builder);
264+
auto fTy = func.getFunctionType();
265+
mlir::Value sourceLine =
266+
fir::factory::locationToLineNo(builder, loc, fTy.getInput(3));
267+
mlir::Value memTy = builder.createIntegerConstant(
268+
loc, builder.getI32Type(), getMemType(op.getDataAttr()));
269+
llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
270+
builder, loc, fTy, bytes, memTy, sourceFile, sourceLine)};
271+
auto callOp = builder.create<fir::CallOp>(loc, func, args);
272+
auto convOp = builder.createConvert(loc, op.getResult().getType(),
273+
callOp.getResult(0));
274+
rewriter.replaceOp(op, convOp);
275+
return mlir::success();
276+
}
277+
278+
// Convert descriptor allocations to function call.
279+
auto boxTy = mlir::dyn_cast_or_null<fir::BaseBoxType>(op.getInType());
215280
mlir::func::FuncOp func =
216281
fir::runtime::getRuntimeFunc<mkRTKey(CUFAllocDesciptor)>(loc, builder);
217-
218282
auto fTy = func.getFunctionType();
219-
mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
220283
mlir::Value sourceLine =
221284
fir::factory::locationToLineNo(builder, loc, fTy.getInput(2));
222285

@@ -245,26 +308,39 @@ struct CufFreeOpConversion : public mlir::OpRewritePattern<cuf::FreeOp> {
245308
mlir::LogicalResult
246309
matchAndRewrite(cuf::FreeOp op,
247310
mlir::PatternRewriter &rewriter) const override {
248-
// Only convert cuf.free on descriptor.
249-
if (!mlir::isa<fir::ReferenceType>(op.getDevptr().getType()))
250-
return failure();
251-
auto refTy = mlir::dyn_cast<fir::ReferenceType>(op.getDevptr().getType());
252-
if (!mlir::isa<fir::BaseBoxType>(refTy.getEleTy()))
253-
return failure();
254-
255311
if (inDeviceContext(op.getOperation())) {
256312
rewriter.eraseOp(op);
257313
return mlir::success();
258314
}
259315

316+
if (!mlir::isa<fir::ReferenceType>(op.getDevptr().getType()))
317+
return failure();
318+
260319
auto mod = op->getParentOfType<mlir::ModuleOp>();
261320
fir::FirOpBuilder builder(rewriter, mod);
262321
mlir::Location loc = op.getLoc();
322+
mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
323+
324+
auto refTy = mlir::dyn_cast<fir::ReferenceType>(op.getDevptr().getType());
325+
if (!mlir::isa<fir::BaseBoxType>(refTy.getEleTy())) {
326+
mlir::func::FuncOp func =
327+
fir::runtime::getRuntimeFunc<mkRTKey(CUFMemFree)>(loc, builder);
328+
auto fTy = func.getFunctionType();
329+
mlir::Value sourceLine =
330+
fir::factory::locationToLineNo(builder, loc, fTy.getInput(3));
331+
mlir::Value memTy = builder.createIntegerConstant(
332+
loc, builder.getI32Type(), getMemType(op.getDataAttr()));
333+
llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
334+
builder, loc, fTy, op.getDevptr(), memTy, sourceFile, sourceLine)};
335+
builder.create<fir::CallOp>(loc, func, args);
336+
rewriter.eraseOp(op);
337+
return mlir::success();
338+
}
339+
340+
// Convert cuf.free on descriptors.
263341
mlir::func::FuncOp func =
264342
fir::runtime::getRuntimeFunc<mkRTKey(CUFFreeDesciptor)>(loc, builder);
265-
266343
auto fTy = func.getFunctionType();
267-
mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
268344
mlir::Value sourceLine =
269345
fir::factory::locationToLineNo(builder, loc, fTy.getInput(2));
270346
llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
@@ -275,29 +351,6 @@ struct CufFreeOpConversion : public mlir::OpRewritePattern<cuf::FreeOp> {
275351
}
276352
};
277353

278-
static int computeWidth(mlir::Location loc, mlir::Type type,
279-
fir::KindMapping &kindMap) {
280-
auto eleTy = fir::unwrapSequenceType(type);
281-
int width = 0;
282-
if (auto t{mlir::dyn_cast<mlir::IntegerType>(eleTy)}) {
283-
width = t.getWidth() / 8;
284-
} else if (auto t{mlir::dyn_cast<mlir::FloatType>(eleTy)}) {
285-
width = t.getWidth() / 8;
286-
} else if (eleTy.isInteger(1)) {
287-
width = 1;
288-
} else if (auto t{mlir::dyn_cast<fir::LogicalType>(eleTy)}) {
289-
int kind = t.getFKind();
290-
width = kindMap.getLogicalBitsize(kind) / 8;
291-
} else if (auto t{mlir::dyn_cast<fir::ComplexType>(eleTy)}) {
292-
int kind = t.getFKind();
293-
int elemSize = kindMap.getRealBitsize(kind) / 8;
294-
width = 2 * elemSize;
295-
} else {
296-
llvm::report_fatal_error("unsupported type");
297-
}
298-
return width;
299-
}
300-
301354
static mlir::Value createConvertOp(mlir::PatternRewriter &rewriter,
302355
mlir::Location loc, mlir::Type toTy,
303356
mlir::Value val) {
@@ -456,16 +509,6 @@ class CufOpConversion : public fir::impl::CufOpConversionBase<CufOpConversion> {
456509
fir::support::getOrSetDataLayout(module, /*allowDefaultLayout=*/false);
457510
fir::LLVMTypeConverter typeConverter(module, /*applyTBAA=*/false,
458511
/*forceUnifiedTBAATree=*/false, *dl);
459-
target.addDynamicallyLegalOp<cuf::AllocOp>([](::cuf::AllocOp op) {
460-
return !mlir::isa<fir::BaseBoxType>(op.getInType());
461-
});
462-
target.addDynamicallyLegalOp<cuf::FreeOp>([](::cuf::FreeOp op) {
463-
if (auto refTy = mlir::dyn_cast_or_null<fir::ReferenceType>(
464-
op.getDevptr().getType())) {
465-
return !mlir::isa<fir::BaseBoxType>(refTy.getEleTy());
466-
}
467-
return true;
468-
});
469512
target.addDynamicallyLegalOp<cuf::DataTransferOp>(
470513
[](::cuf::DataTransferOp op) {
471514
mlir::Type srcTy = fir::unwrapRefType(op.getSrc().getType());
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
// RUN: fir-opt --cuf-convert %s | FileCheck %s
2+
3+
module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<f80, dense<128> : vector<2xi64>>, #dlti.dl_entry<i128, dense<128> : vector<2xi64>>, #dlti.dl_entry<i64, dense<64> : vector<2xi64>>, #dlti.dl_entry<!llvm.ptr<272>, dense<64> : vector<4xi64>>, #dlti.dl_entry<!llvm.ptr<271>, dense<32> : vector<4xi64>>, #dlti.dl_entry<!llvm.ptr<270>, dense<32> : vector<4xi64>>, #dlti.dl_entry<f128, dense<128> : vector<2xi64>>, #dlti.dl_entry<f64, dense<64> : vector<2xi64>>, #dlti.dl_entry<f16, dense<16> : vector<2xi64>>, #dlti.dl_entry<i32, dense<32> : vector<2xi64>>, #dlti.dl_entry<i16, dense<16> : vector<2xi64>>, #dlti.dl_entry<i8, dense<8> : vector<2xi64>>, #dlti.dl_entry<i1, dense<8> : vector<2xi64>>, #dlti.dl_entry<!llvm.ptr, dense<64> : vector<4xi64>>, #dlti.dl_entry<"dlti.endianness", "little">, #dlti.dl_entry<"dlti.stack_alignment", 128 : i64>>} {
4+
5+
func.func @_QPsub1() {
6+
%0 = cuf.alloc i32 {bindc_name = "idev", data_attr = #cuf.cuda<device>, uniq_name = "_QFsub1Eidev"} -> !fir.ref<i32>
7+
%1:2 = hlfir.declare %0 {data_attr = #cuf.cuda<device>, uniq_name = "_QFsub1Eidev"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
8+
cuf.free %1#1 : !fir.ref<i32> {data_attr = #cuf.cuda<device>}
9+
return
10+
}
11+
12+
// CHECK-LABEL: func.func @_QPsub1()
13+
// CHECK: %[[BYTES:.*]] = fir.convert %c4{{.*}} : (index) -> i64
14+
// CHECK: %[[ALLOC:.*]] = fir.call @_FortranACUFMemAlloc(%[[BYTES]], %c0{{.*}}, %{{.*}}, %{{.*}}) : (i64, i32, !fir.ref<i8>, i32) -> !fir.llvm_ptr<i8>
15+
// CHECK: %[[CONV:.*]] = fir.convert %[[ALLOC]] : (!fir.llvm_ptr<i8>) -> !fir.ref<i32>
16+
// CHECK: %[[DECL:.*]]:2 = hlfir.declare %[[CONV]] {data_attr = #cuf.cuda<device>, uniq_name = "_QFsub1Eidev"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
17+
// CHECK: %[[DEVPTR:.*]] = fir.convert %[[DECL]]#1 : (!fir.ref<i32>) -> !fir.llvm_ptr<i8>
18+
// CHECK: fir.call @_FortranACUFMemFree(%[[DEVPTR]], %c0{{.*}}, %{{.*}}, %{{.*}}) : (!fir.llvm_ptr<i8>, i32, !fir.ref<i8>, i32) -> none
19+
20+
func.func @_QPsub2() {
21+
%0 = cuf.alloc !fir.array<10xf32> {bindc_name = "a", data_attr = #cuf.cuda<device>, uniq_name = "_QMcuda_varFcuda_alloc_freeEa"} -> !fir.ref<!fir.array<10xf32>>
22+
cuf.free %0 : !fir.ref<!fir.array<10xf32>> {data_attr = #cuf.cuda<device>}
23+
return
24+
}
25+
26+
// CHECK-LABEL: func.func @_QPsub2()
27+
// CHECK: %[[BYTES:.*]] = arith.muli %c10{{.*}}, %c4{{.*}} : index
28+
// CHECK: %[[CONV_BYTES:.*]] = fir.convert %[[BYTES]] : (index) -> i64
29+
// CHECK: %{{.*}} = fir.call @_FortranACUFMemAlloc(%[[CONV_BYTES]], %c0{{.*}}, %{{.*}}, %{{.*}}) : (i64, i32, !fir.ref<i8>, i32) -> !fir.llvm_ptr<i8>
30+
// CHECK: fir.call @_FortranACUFMemFree
31+
32+
func.func @_QPsub3(%arg0: !fir.ref<i32> {fir.bindc_name = "n"}, %arg1: !fir.ref<i32> {fir.bindc_name = "m"}) {
33+
%0 = fir.dummy_scope : !fir.dscope
34+
%1:2 = hlfir.declare %arg0 dummy_scope %0 {uniq_name = "_QFsub3En"} : (!fir.ref<i32>, !fir.dscope) -> (!fir.ref<i32>, !fir.ref<i32>)
35+
%2:2 = hlfir.declare %arg1 dummy_scope %0 {uniq_name = "_QFsub3Em"} : (!fir.ref<i32>, !fir.dscope) -> (!fir.ref<i32>, !fir.ref<i32>)
36+
%3 = fir.load %1#0 : !fir.ref<i32>
37+
%4 = fir.convert %3 : (i32) -> i64
38+
%5 = fir.convert %4 : (i64) -> index
39+
%c0 = arith.constant 0 : index
40+
%6 = arith.cmpi sgt, %5, %c0 : index
41+
%7 = arith.select %6, %5, %c0 : index
42+
%8 = fir.load %2#0 : !fir.ref<i32>
43+
%9 = fir.convert %8 : (i32) -> i64
44+
%10 = fir.convert %9 : (i64) -> index
45+
%c0_0 = arith.constant 0 : index
46+
%11 = arith.cmpi sgt, %10, %c0_0 : index
47+
%12 = arith.select %11, %10, %c0_0 : index
48+
%13 = cuf.alloc !fir.array<?x?xi32>, %7, %12 : index, index {bindc_name = "idev", data_attr = #cuf.cuda<device>, uniq_name = "_QFsub3Eidev"} -> !fir.ref<!fir.array<?x?xi32>>
49+
%14 = fir.shape %7, %12 : (index, index) -> !fir.shape<2>
50+
%15:2 = hlfir.declare %13(%14) {data_attr = #cuf.cuda<device>, uniq_name = "_QFsub3Eidev"} : (!fir.ref<!fir.array<?x?xi32>>, !fir.shape<2>) -> (!fir.box<!fir.array<?x?xi32>>, !fir.ref<!fir.array<?x?xi32>>)
51+
cuf.free %15#1 : !fir.ref<!fir.array<?x?xi32>> {data_attr = #cuf.cuda<device>}
52+
return
53+
}
54+
55+
// CHECK-LABEL: func.func @_QPsub3
56+
// CHECK: %[[N:.*]] = arith.select
57+
// CHECK: %[[M:.*]] = arith.select
58+
// CHECK: %[[NBELEM:.*]] = arith.muli %[[N]], %[[M]] : index
59+
// CHECK: %[[BYTES:.*]] = arith.muli %[[NBELEM]], %c4{{.*}} : index
60+
// CHECK: %[[CONV_BYTES:.*]] = fir.convert %[[BYTES]] : (index) -> i64
61+
// CHECK: %{{.*}} = fir.call @_FortranACUFMemAlloc(%[[CONV_BYTES]], %c0{{.*}}, %{{.*}}, %{{.*}}) : (i64, i32, !fir.ref<i8>, i32) -> !fir.llvm_ptr<i8>
62+
// CHECK: fir.call @_FortranACUFMemFree
63+
64+
} // end module

flang/test/Fir/CUDA/cuda-allocate.fir

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -26,17 +26,6 @@ func.func @_QPsub1() {
2626
// CHECK: %[[BOX_NONE:.*]] = fir.convert %[[DECL_DESC]]#1 : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>) -> !fir.ref<!fir.box<none>>
2727
// CHECK: fir.call @_FortranACUFFreeDesciptor(%[[BOX_NONE]], %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, !fir.ref<i8>, i32) -> none
2828

29-
// Check operations that should not be transformed yet.
30-
func.func @_QPsub2() {
31-
%0 = cuf.alloc !fir.array<10xf32> {bindc_name = "a", data_attr = #cuf.cuda<device>, uniq_name = "_QMcuda_varFcuda_alloc_freeEa"} -> !fir.ref<!fir.array<10xf32>>
32-
cuf.free %0 : !fir.ref<!fir.array<10xf32>> {data_attr = #cuf.cuda<device>}
33-
return
34-
}
35-
36-
// CHECK-LABEL: func.func @_QPsub2()
37-
// CHECK: cuf.alloc !fir.array<10xf32>
38-
// CHECK: cuf.free %{{.*}} : !fir.ref<!fir.array<10xf32>>
39-
4029
fir.global @_QMmod1Ea {data_attr = #cuf.cuda<device>} : !fir.box<!fir.heap<!fir.array<?xf32>>> {
4130
%0 = fir.zero_bits !fir.heap<!fir.array<?xf32>>
4231
%c0 = arith.constant 0 : index

0 commit comments

Comments
 (0)