Skip to content

Commit e5092c3

Browse files
authored
[flang][cuda] Support malloc and free conversion in gpu module (#116112)
1 parent 1f0e0da commit e5092c3

File tree

3 files changed

+53
-13
lines changed

3 files changed

+53
-13
lines changed

flang/lib/Optimizer/CodeGen/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ add_flang_library(FIRCodeGen
2323
FIRSupport
2424
MLIRComplexToLLVM
2525
MLIRComplexToStandard
26+
MLIRGPUDialect
2627
MLIRMathToFuncs
2728
MLIRMathToLLVM
2829
MLIRMathToLibm

flang/lib/Optimizer/CodeGen/CodeGen.cpp

Lines changed: 36 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
4242
#include "mlir/Dialect/Arith/IR/Arith.h"
4343
#include "mlir/Dialect/DLTI/DLTI.h"
44+
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
4445
#include "mlir/Dialect/LLVMIR/LLVMAttrs.h"
4546
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
4647
#include "mlir/Dialect/LLVMIR/Transforms/AddComdats.h"
@@ -920,17 +921,19 @@ struct EmboxCharOpConversion : public fir::FIROpConversion<fir::EmboxCharOp> {
920921
};
921922
} // namespace
922923

923-
/// Return the LLVMFuncOp corresponding to the standard malloc call.
924+
template <typename ModuleOp>
924925
static mlir::SymbolRefAttr
925-
getMalloc(fir::AllocMemOp op, mlir::ConversionPatternRewriter &rewriter) {
926+
getMallocInModule(ModuleOp mod, fir::AllocMemOp op,
927+
mlir::ConversionPatternRewriter &rewriter) {
926928
static constexpr char mallocName[] = "malloc";
927-
auto module = op->getParentOfType<mlir::ModuleOp>();
928-
if (auto mallocFunc = module.lookupSymbol<mlir::LLVM::LLVMFuncOp>(mallocName))
929+
if (auto mallocFunc =
930+
mod.template lookupSymbol<mlir::LLVM::LLVMFuncOp>(mallocName))
929931
return mlir::SymbolRefAttr::get(mallocFunc);
930-
if (auto userMalloc = module.lookupSymbol<mlir::func::FuncOp>(mallocName))
932+
if (auto userMalloc =
933+
mod.template lookupSymbol<mlir::func::FuncOp>(mallocName))
931934
return mlir::SymbolRefAttr::get(userMalloc);
932-
mlir::OpBuilder moduleBuilder(
933-
op->getParentOfType<mlir::ModuleOp>().getBodyRegion());
935+
936+
mlir::OpBuilder moduleBuilder(mod.getBodyRegion());
934937
auto indexType = mlir::IntegerType::get(op.getContext(), 64);
935938
auto mallocDecl = moduleBuilder.create<mlir::LLVM::LLVMFuncOp>(
936939
op.getLoc(), mallocName,
@@ -940,6 +943,15 @@ getMalloc(fir::AllocMemOp op, mlir::ConversionPatternRewriter &rewriter) {
940943
return mlir::SymbolRefAttr::get(mallocDecl);
941944
}
942945

946+
/// Return a symbol reference to the standard malloc function for \p op.
/// When \p op is nested in a gpu.module, malloc is looked up (or declared)
/// in that GPU module; otherwise the enclosing mlir::ModuleOp is used.
947+
static mlir::SymbolRefAttr
948+
getMalloc(fir::AllocMemOp op, mlir::ConversionPatternRewriter &rewriter) {
949+
// Ops inside a gpu.module resolve malloc within that GPU module.
if (auto mod = op->getParentOfType<mlir::gpu::GPUModuleOp>())
950+
return getMallocInModule(mod, op, rewriter);
951+
// Otherwise fall back to the enclosing mlir::ModuleOp.
auto mod = op->getParentOfType<mlir::ModuleOp>();
952+
return getMallocInModule(mod, op, rewriter);
953+
}
954+
943955
/// Helper function for generating the LLVM IR that computes the distance
944956
/// in bytes between adjacent elements pointed to by a pointer
945957
/// of type \p ptrTy. The result is returned as a value of \p idxTy integer
@@ -1016,18 +1028,20 @@ struct AllocMemOpConversion : public fir::FIROpConversion<fir::AllocMemOp> {
10161028
} // namespace
10171029

10181030
/// Return the LLVMFuncOp corresponding to the standard free call.
1019-
static mlir::SymbolRefAttr getFree(fir::FreeMemOp op,
1020-
mlir::ConversionPatternRewriter &rewriter) {
1031+
template <typename ModuleOp>
1032+
static mlir::SymbolRefAttr
1033+
getFreeInModule(ModuleOp mod, fir::FreeMemOp op,
1034+
mlir::ConversionPatternRewriter &rewriter) {
10211035
static constexpr char freeName[] = "free";
1022-
auto module = op->getParentOfType<mlir::ModuleOp>();
10231036
// Check if free already defined in the module.
1024-
if (auto freeFunc = module.lookupSymbol<mlir::LLVM::LLVMFuncOp>(freeName))
1037+
if (auto freeFunc =
1038+
mod.template lookupSymbol<mlir::LLVM::LLVMFuncOp>(freeName))
10251039
return mlir::SymbolRefAttr::get(freeFunc);
10261040
if (auto freeDefinedByUser =
1027-
module.lookupSymbol<mlir::func::FuncOp>(freeName))
1041+
mod.template lookupSymbol<mlir::func::FuncOp>(freeName))
10281042
return mlir::SymbolRefAttr::get(freeDefinedByUser);
10291043
// Create llvm declaration for free.
1030-
mlir::OpBuilder moduleBuilder(module.getBodyRegion());
1044+
mlir::OpBuilder moduleBuilder(mod.getBodyRegion());
10311045
auto voidType = mlir::LLVM::LLVMVoidType::get(op.getContext());
10321046
auto freeDecl = moduleBuilder.create<mlir::LLVM::LLVMFuncOp>(
10331047
rewriter.getUnknownLoc(), freeName,
@@ -1037,6 +1051,14 @@ static mlir::SymbolRefAttr getFree(fir::FreeMemOp op,
10371051
return mlir::SymbolRefAttr::get(freeDecl);
10381052
}
10391053

1054+
/// Return a symbol reference to the standard free function for \p op.
/// When \p op is nested in a gpu.module, free is looked up (or declared)
/// in that GPU module; otherwise the enclosing mlir::ModuleOp is used.
static mlir::SymbolRefAttr getFree(fir::FreeMemOp op,
1055+
mlir::ConversionPatternRewriter &rewriter) {
1056+
// Ops inside a gpu.module resolve free within that GPU module.
if (auto mod = op->getParentOfType<mlir::gpu::GPUModuleOp>())
1057+
return getFreeInModule(mod, op, rewriter);
1058+
// Otherwise fall back to the enclosing mlir::ModuleOp.
auto mod = op->getParentOfType<mlir::ModuleOp>();
1059+
return getFreeInModule(mod, op, rewriter);
1060+
}
1061+
10401062
static unsigned getDimension(mlir::LLVM::LLVMArrayType ty) {
10411063
unsigned result = 1;
10421064
for (auto eleTy =
@@ -3730,6 +3752,7 @@ class FIRToLLVMLowering
37303752
mlir::configureOpenMPToLLVMConversionLegality(target, typeConverter);
37313753
target.addLegalDialect<mlir::omp::OpenMPDialect>();
37323754
target.addLegalDialect<mlir::acc::OpenACCDialect>();
3755+
target.addLegalDialect<mlir::gpu::GPUDialect>();
37333756

37343757
// required NOPs for applying a full conversion
37353758
target.addLegalOp<mlir::ModuleOp>();

flang/test/Fir/convert-to-llvm.fir

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2776,3 +2776,19 @@ func.func @coordinate_array_unknown_size_1d(%arg0: !fir.ptr<!fir.array<? x i32>>
27762776
fir.global common @c_(dense<0> : vector<4294967296xi8>) : !fir.array<4294967296xi8>
27772777

27782778
// CHECK: llvm.mlir.global common @c_(dense<0> : vector<4294967296xi8>) {addr_space = 0 : i32} : !llvm.array<4294967296 x i8>
2779+
2780+
// -----
2781+
2782+
// Test input: fir.allocmem / fir.freemem nested in a gpu.func of a gpu.module.
// The CHECK lines below expect malloc / free declarations and calls after the
// gpu.module header.
gpu.module @cuda_device_mod {
2783+
gpu.func @test_alloc_and_freemem_one() {
2784+
%z0 = fir.allocmem i32
2785+
fir.freemem %z0 : !fir.heap<i32>
2786+
gpu.return
2787+
}
2788+
}
2789+
2790+
// CHECK: gpu.module @cuda_device_mod {
2791+
// CHECK: llvm.func @free(!llvm.ptr)
2792+
// CHECK: llvm.func @malloc(i64) -> !llvm.ptr
2793+
// CHECK: llvm.call @malloc
2794+
// CHECK: llvm.call @free

0 commit comments

Comments
 (0)