-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[flang][cuda] Convert cuf.alloc and cuf.free for scalar and arrays #110055
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Conversation
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
@llvm/pr-subscribers-flang-fir-hlfir Author: Valentin Clement (バレンタイン クレメン) (clementval) ChangesThis patch adds more conversion of cuf.alloc and cuf.free for scalars, constant size arrays and dynamic size arrays Full diff: https://github.com/llvm/llvm-project/pull/110055.diff 3 Files Affected:
diff --git a/flang/lib/Optimizer/Transforms/CufOpConversion.cpp b/flang/lib/Optimizer/Transforms/CufOpConversion.cpp
index f8ace2dd96a0d8..fff026bbda2d40 100644
--- a/flang/lib/Optimizer/Transforms/CufOpConversion.cpp
+++ b/flang/lib/Optimizer/Transforms/CufOpConversion.cpp
@@ -183,6 +183,29 @@ static bool inDeviceContext(mlir::Operation *op) {
return false;
}
+static int computeWidth(mlir::Location loc, mlir::Type type,
+ fir::KindMapping &kindMap) {
+ auto eleTy = fir::unwrapSequenceType(type);
+ int width = 0;
+ if (auto t{mlir::dyn_cast<mlir::IntegerType>(eleTy)}) {
+ width = t.getWidth() / 8;
+ } else if (auto t{mlir::dyn_cast<mlir::FloatType>(eleTy)}) {
+ width = t.getWidth() / 8;
+ } else if (eleTy.isInteger(1)) {
+ width = 1;
+ } else if (auto t{mlir::dyn_cast<fir::LogicalType>(eleTy)}) {
+ int kind = t.getFKind();
+ width = kindMap.getLogicalBitsize(kind) / 8;
+ } else if (auto t{mlir::dyn_cast<fir::ComplexType>(eleTy)}) {
+ int kind = t.getFKind();
+ int elemSize = kindMap.getRealBitsize(kind) / 8;
+ width = 2 * elemSize;
+ } else {
+ llvm::report_fatal_error("unsupported type");
+ }
+ return width;
+}
+
struct CufAllocOpConversion : public mlir::OpRewritePattern<cuf::AllocOp> {
using OpRewritePattern::OpRewritePattern;
@@ -193,11 +216,6 @@ struct CufAllocOpConversion : public mlir::OpRewritePattern<cuf::AllocOp> {
mlir::LogicalResult
matchAndRewrite(cuf::AllocOp op,
mlir::PatternRewriter &rewriter) const override {
- auto boxTy = mlir::dyn_cast_or_null<fir::BaseBoxType>(op.getInType());
-
- // Only convert cuf.alloc that allocates a descriptor.
- if (!boxTy)
- return failure();
if (inDeviceContext(op.getOperation())) {
// In device context just replace the cuf.alloc operation with a fir.alloc
@@ -212,11 +230,56 @@ struct CufAllocOpConversion : public mlir::OpRewritePattern<cuf::AllocOp> {
auto mod = op->getParentOfType<mlir::ModuleOp>();
fir::FirOpBuilder builder(rewriter, mod);
mlir::Location loc = op.getLoc();
+ mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
+
+ if (!mlir::dyn_cast_or_null<fir::BaseBoxType>(op.getInType())) {
+ // Convert scalar and known size array allocations.
+ mlir::Value bytes;
+ fir::KindMapping kindMap{fir::getKindMapping(mod)};
+ if (fir::isa_trivial(op.getInType())) {
+ int width = computeWidth(loc, op.getInType(), kindMap);
+ bytes =
+ builder.createIntegerConstant(loc, builder.getIndexType(), width);
+ } else if (auto seqTy = mlir::dyn_cast_or_null<fir::SequenceType>(
+ op.getInType())) {
+ mlir::Value width = builder.createIntegerConstant(
+ loc, builder.getIndexType(),
+ computeWidth(loc, seqTy.getEleTy(), kindMap));
+ mlir::Value nbElem;
+ if (fir::sequenceWithNonConstantShape(seqTy)) {
+ assert(!op.getShape().empty() && "expect shape with dynamic arrays");
+ nbElem = builder.loadIfRef(loc, op.getShape()[0]);
+ for (unsigned i = 1; i < op.getShape().size(); ++i) {
+ nbElem = rewriter.create<mlir::arith::MulIOp>(
+ loc, nbElem, builder.loadIfRef(loc, op.getShape()[i]));
+ }
+ } else {
+ nbElem = builder.createIntegerConstant(loc, builder.getIndexType(),
+ seqTy.getConstantArraySize());
+ }
+ bytes = rewriter.create<mlir::arith::MulIOp>(loc, nbElem, width);
+ }
+ mlir::func::FuncOp func =
+ fir::runtime::getRuntimeFunc<mkRTKey(CUFMemAlloc)>(loc, builder);
+ auto fTy = func.getFunctionType();
+ mlir::Value sourceLine =
+ fir::factory::locationToLineNo(builder, loc, fTy.getInput(3));
+ mlir::Value memTy = builder.createIntegerConstant(
+ loc, builder.getI32Type(), kMemTypeDevice);
+ llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
+ builder, loc, fTy, bytes, memTy, sourceFile, sourceLine)};
+ auto callOp = builder.create<fir::CallOp>(loc, func, args);
+ auto convOp = builder.createConvert(loc, op.getResult().getType(),
+ callOp.getResult(0));
+ rewriter.replaceOp(op, convOp);
+ return mlir::success();
+ }
+
+ // Convert descriptor allocations to function call.
+ auto boxTy = mlir::dyn_cast_or_null<fir::BaseBoxType>(op.getInType());
mlir::func::FuncOp func =
fir::runtime::getRuntimeFunc<mkRTKey(CUFAllocDesciptor)>(loc, builder);
-
auto fTy = func.getFunctionType();
- mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
mlir::Value sourceLine =
fir::factory::locationToLineNo(builder, loc, fTy.getInput(2));
@@ -245,26 +308,39 @@ struct CufFreeOpConversion : public mlir::OpRewritePattern<cuf::FreeOp> {
mlir::LogicalResult
matchAndRewrite(cuf::FreeOp op,
mlir::PatternRewriter &rewriter) const override {
- // Only convert cuf.free on descriptor.
- if (!mlir::isa<fir::ReferenceType>(op.getDevptr().getType()))
- return failure();
- auto refTy = mlir::dyn_cast<fir::ReferenceType>(op.getDevptr().getType());
- if (!mlir::isa<fir::BaseBoxType>(refTy.getEleTy()))
- return failure();
-
if (inDeviceContext(op.getOperation())) {
rewriter.eraseOp(op);
return mlir::success();
}
+ if (!mlir::isa<fir::ReferenceType>(op.getDevptr().getType()))
+ return failure();
+
auto mod = op->getParentOfType<mlir::ModuleOp>();
fir::FirOpBuilder builder(rewriter, mod);
mlir::Location loc = op.getLoc();
+ mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
+
+ auto refTy = mlir::dyn_cast<fir::ReferenceType>(op.getDevptr().getType());
+ if (!mlir::isa<fir::BaseBoxType>(refTy.getEleTy())) {
+ mlir::func::FuncOp func =
+ fir::runtime::getRuntimeFunc<mkRTKey(CUFMemFree)>(loc, builder);
+ auto fTy = func.getFunctionType();
+ mlir::Value sourceLine =
+ fir::factory::locationToLineNo(builder, loc, fTy.getInput(3));
+ mlir::Value memTy = builder.createIntegerConstant(
+ loc, builder.getI32Type(), kMemTypeDevice);
+ llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
+ builder, loc, fTy, op.getDevptr(), memTy, sourceFile, sourceLine)};
+ builder.create<fir::CallOp>(loc, func, args);
+ rewriter.eraseOp(op);
+ return mlir::success();
+ }
+
+ // Convert cuf.free on descriptors.
mlir::func::FuncOp func =
fir::runtime::getRuntimeFunc<mkRTKey(CUFFreeDesciptor)>(loc, builder);
-
auto fTy = func.getFunctionType();
- mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
mlir::Value sourceLine =
fir::factory::locationToLineNo(builder, loc, fTy.getInput(2));
llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
@@ -275,29 +351,6 @@ struct CufFreeOpConversion : public mlir::OpRewritePattern<cuf::FreeOp> {
}
};
-static int computeWidth(mlir::Location loc, mlir::Type type,
- fir::KindMapping &kindMap) {
- auto eleTy = fir::unwrapSequenceType(type);
- int width = 0;
- if (auto t{mlir::dyn_cast<mlir::IntegerType>(eleTy)}) {
- width = t.getWidth() / 8;
- } else if (auto t{mlir::dyn_cast<mlir::FloatType>(eleTy)}) {
- width = t.getWidth() / 8;
- } else if (eleTy.isInteger(1)) {
- width = 1;
- } else if (auto t{mlir::dyn_cast<fir::LogicalType>(eleTy)}) {
- int kind = t.getFKind();
- width = kindMap.getLogicalBitsize(kind) / 8;
- } else if (auto t{mlir::dyn_cast<fir::ComplexType>(eleTy)}) {
- int kind = t.getFKind();
- int elemSize = kindMap.getRealBitsize(kind) / 8;
- width = 2 * elemSize;
- } else {
- llvm::report_fatal_error("unsupported type");
- }
- return width;
-}
-
static mlir::Value createConvertOp(mlir::PatternRewriter &rewriter,
mlir::Location loc, mlir::Type toTy,
mlir::Value val) {
@@ -456,16 +509,6 @@ class CufOpConversion : public fir::impl::CufOpConversionBase<CufOpConversion> {
fir::support::getOrSetDataLayout(module, /*allowDefaultLayout=*/false);
fir::LLVMTypeConverter typeConverter(module, /*applyTBAA=*/false,
/*forceUnifiedTBAATree=*/false, *dl);
- target.addDynamicallyLegalOp<cuf::AllocOp>([](::cuf::AllocOp op) {
- return !mlir::isa<fir::BaseBoxType>(op.getInType());
- });
- target.addDynamicallyLegalOp<cuf::FreeOp>([](::cuf::FreeOp op) {
- if (auto refTy = mlir::dyn_cast_or_null<fir::ReferenceType>(
- op.getDevptr().getType())) {
- return !mlir::isa<fir::BaseBoxType>(refTy.getEleTy());
- }
- return true;
- });
target.addDynamicallyLegalOp<cuf::DataTransferOp>(
[](::cuf::DataTransferOp op) {
mlir::Type srcTy = fir::unwrapRefType(op.getSrc().getType());
diff --git a/flang/test/Fir/CUDA/cuda-alloc-free.fir b/flang/test/Fir/CUDA/cuda-alloc-free.fir
new file mode 100644
index 00000000000000..25821418a40f11
--- /dev/null
+++ b/flang/test/Fir/CUDA/cuda-alloc-free.fir
@@ -0,0 +1,64 @@
+// RUN: fir-opt --cuf-convert %s | FileCheck %s
+
+module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<f80, dense<128> : vector<2xi64>>, #dlti.dl_entry<i128, dense<128> : vector<2xi64>>, #dlti.dl_entry<i64, dense<64> : vector<2xi64>>, #dlti.dl_entry<!llvm.ptr<272>, dense<64> : vector<4xi64>>, #dlti.dl_entry<!llvm.ptr<271>, dense<32> : vector<4xi64>>, #dlti.dl_entry<!llvm.ptr<270>, dense<32> : vector<4xi64>>, #dlti.dl_entry<f128, dense<128> : vector<2xi64>>, #dlti.dl_entry<f64, dense<64> : vector<2xi64>>, #dlti.dl_entry<f16, dense<16> : vector<2xi64>>, #dlti.dl_entry<i32, dense<32> : vector<2xi64>>, #dlti.dl_entry<i16, dense<16> : vector<2xi64>>, #dlti.dl_entry<i8, dense<8> : vector<2xi64>>, #dlti.dl_entry<i1, dense<8> : vector<2xi64>>, #dlti.dl_entry<!llvm.ptr, dense<64> : vector<4xi64>>, #dlti.dl_entry<"dlti.endianness", "little">, #dlti.dl_entry<"dlti.stack_alignment", 128 : i64>>} {
+
+func.func @_QPsub1() {
+ %0 = cuf.alloc i32 {bindc_name = "idev", data_attr = #cuf.cuda<device>, uniq_name = "_QFsub1Eidev"} -> !fir.ref<i32>
+ %1:2 = hlfir.declare %0 {data_attr = #cuf.cuda<device>, uniq_name = "_QFsub1Eidev"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+ cuf.free %1#1 : !fir.ref<i32> {data_attr = #cuf.cuda<device>}
+ return
+}
+
+// CHECK-LABEL: func.func @_QPsub1()
+// CHECK: %[[BYTES:.*]] = fir.convert %c4{{.*}} : (index) -> i64
+// CHECK: %[[ALLOC:.*]] = fir.call @_FortranACUFMemAlloc(%[[BYTES]], %c0{{.*}}, %{{.*}}, %{{.*}}) : (i64, i32, !fir.ref<i8>, i32) -> !fir.llvm_ptr<i8>
+// CHECK: %[[CONV:.*]] = fir.convert %3 : (!fir.llvm_ptr<i8>) -> !fir.ref<i32>
+// CHECK: %[[DECL:.*]]:2 = hlfir.declare %[[CONV]] {data_attr = #cuf.cuda<device>, uniq_name = "_QFsub1Eidev"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+// CHECK: %[[DEVPTR:.*]] = fir.convert %[[DECL]]#1 : (!fir.ref<i32>) -> !fir.llvm_ptr<i8>
+// CHECK: fir.call @_FortranACUFMemFree(%[[DEVPTR]], %c0{{.*}}, %{{.*}}, %{{.*}}) : (!fir.llvm_ptr<i8>, i32, !fir.ref<i8>, i32) -> none
+
+func.func @_QPsub2() {
+ %0 = cuf.alloc !fir.array<10xf32> {bindc_name = "a", data_attr = #cuf.cuda<device>, uniq_name = "_QMcuda_varFcuda_alloc_freeEa"} -> !fir.ref<!fir.array<10xf32>>
+ cuf.free %0 : !fir.ref<!fir.array<10xf32>> {data_attr = #cuf.cuda<device>}
+ return
+}
+
+// CHECK-LABEL: func.func @_QPsub2()
+// CHECK: %[[BYTES:.*]] = arith.muli %c10{{.*}}, %c4{{.*}} : index
+// CHECK: %[[CONV_BYTES:.*]] = fir.convert %[[BYTES]] : (index) -> i64
+// CHECK: %{{.*}} = fir.call @_FortranACUFMemAlloc(%[[CONV_BYTES]], %c0{{.*}}, %{{.*}}, %{{.*}}) : (i64, i32, !fir.ref<i8>, i32) -> !fir.llvm_ptr<i8>
+// CHECK: fir.call @_FortranACUFMemFree
+
+func.func @_QPsub3(%arg0: !fir.ref<i32> {fir.bindc_name = "n"}, %arg1: !fir.ref<i32> {fir.bindc_name = "m"}) {
+ %0 = fir.dummy_scope : !fir.dscope
+ %1:2 = hlfir.declare %arg0 dummy_scope %0 {uniq_name = "_QFsub3En"} : (!fir.ref<i32>, !fir.dscope) -> (!fir.ref<i32>, !fir.ref<i32>)
+ %2:2 = hlfir.declare %arg1 dummy_scope %0 {uniq_name = "_QFsub3Em"} : (!fir.ref<i32>, !fir.dscope) -> (!fir.ref<i32>, !fir.ref<i32>)
+ %3 = fir.load %1#0 : !fir.ref<i32>
+ %4 = fir.convert %3 : (i32) -> i64
+ %5 = fir.convert %4 : (i64) -> index
+ %c0 = arith.constant 0 : index
+ %6 = arith.cmpi sgt, %5, %c0 : index
+ %7 = arith.select %6, %5, %c0 : index
+ %8 = fir.load %2#0 : !fir.ref<i32>
+ %9 = fir.convert %8 : (i32) -> i64
+ %10 = fir.convert %9 : (i64) -> index
+ %c0_0 = arith.constant 0 : index
+ %11 = arith.cmpi sgt, %10, %c0_0 : index
+ %12 = arith.select %11, %10, %c0_0 : index
+ %13 = cuf.alloc !fir.array<?x?xi32>, %7, %12 : index, index {bindc_name = "idev", data_attr = #cuf.cuda<device>, uniq_name = "_QFsub3Eidev"} -> !fir.ref<!fir.array<?x?xi32>>
+ %14 = fir.shape %7, %12 : (index, index) -> !fir.shape<2>
+ %15:2 = hlfir.declare %13(%14) {data_attr = #cuf.cuda<device>, uniq_name = "_QFsub3Eidev"} : (!fir.ref<!fir.array<?x?xi32>>, !fir.shape<2>) -> (!fir.box<!fir.array<?x?xi32>>, !fir.ref<!fir.array<?x?xi32>>)
+ cuf.free %15#1 : !fir.ref<!fir.array<?x?xi32>> {data_attr = #cuf.cuda<device>}
+ return
+}
+
+// CHECK-LABEL: func.func @_QPsub3
+// CHECK: %[[N:.*]] = arith.select
+// CHECK: %[[M:.*]] = arith.select
+// CHECK: %[[NBELEM:.*]] = arith.muli %[[N]], %[[M]] : index
+// CHECK: %[[BYTES:.*]] = arith.muli %[[NBELEM]], %c4{{.*}} : index
+// CHECK: %[[CONV_BYTES:.*]] = fir.convert %[[BYTES]] : (index) -> i64
+// CHECK: %{{.*}} = fir.call @_FortranACUFMemAlloc(%[[CONV_BYTES]], %c0{{.*}}, %{{.*}}, %{{.*}}) : (i64, i32, !fir.ref<i8>, i32) -> !fir.llvm_ptr<i8>
+// CHECK: fir.call @_FortranACUFMemFree
+
+} // end module
diff --git a/flang/test/Fir/CUDA/cuda-allocate.fir b/flang/test/Fir/CUDA/cuda-allocate.fir
index 65c68bb69301af..d68ff894d5af5a 100644
--- a/flang/test/Fir/CUDA/cuda-allocate.fir
+++ b/flang/test/Fir/CUDA/cuda-allocate.fir
@@ -26,17 +26,6 @@ func.func @_QPsub1() {
// CHECK: %[[BOX_NONE:.*]] = fir.convert %[[DECL_DESC]]#1 : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>) -> !fir.ref<!fir.box<none>>
// CHECK: fir.call @_FortranACUFFreeDesciptor(%[[BOX_NONE]], %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, !fir.ref<i8>, i32) -> none
-// Check operations that should not be transformed yet.
-func.func @_QPsub2() {
- %0 = cuf.alloc !fir.array<10xf32> {bindc_name = "a", data_attr = #cuf.cuda<device>, uniq_name = "_QMcuda_varFcuda_alloc_freeEa"} -> !fir.ref<!fir.array<10xf32>>
- cuf.free %0 : !fir.ref<!fir.array<10xf32>> {data_attr = #cuf.cuda<device>}
- return
-}
-
-// CHECK-LABEL: func.func @_QPsub2()
-// CHECK: cuf.alloc !fir.array<10xf32>
-// CHECK: cuf.free %{{.*}} : !fir.ref<!fir.array<10xf32>>
-
fir.global @_QMmod1Ea {data_attr = #cuf.cuda<device>} : !fir.box<!fir.heap<!fir.array<?xf32>>> {
%0 = fir.zero_bits !fir.heap<!fir.array<?xf32>>
%c0 = arith.constant 0 : index
|
wangzpgi
approved these changes
Sep 25, 2024
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
This patch adds more conversion of cuf.alloc and cuf.free for scalars, constant size arrays and dynamic size arrays