Skip to content

[flang][cuda] Support malloc and free conversion in gpu module #116112

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Nov 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions flang/lib/Optimizer/CodeGen/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ add_flang_library(FIRCodeGen
FIRSupport
MLIRComplexToLLVM
MLIRComplexToStandard
MLIRGPUDialect
MLIRMathToFuncs
MLIRMathToLLVM
MLIRMathToLibm
Expand Down
49 changes: 36 additions & 13 deletions flang/lib/Optimizer/CodeGen/CodeGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/DLTI/DLTI.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMAttrs.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/Transforms/AddComdats.h"
Expand Down Expand Up @@ -920,17 +921,19 @@ struct EmboxCharOpConversion : public fir::FIROpConversion<fir::EmboxCharOp> {
};
} // namespace

/// Return the LLVMFuncOp corresponding to the standard malloc call.
template <typename ModuleOp>
static mlir::SymbolRefAttr
getMalloc(fir::AllocMemOp op, mlir::ConversionPatternRewriter &rewriter) {
getMallocInModule(ModuleOp mod, fir::AllocMemOp op,
mlir::ConversionPatternRewriter &rewriter) {
static constexpr char mallocName[] = "malloc";
auto module = op->getParentOfType<mlir::ModuleOp>();
if (auto mallocFunc = module.lookupSymbol<mlir::LLVM::LLVMFuncOp>(mallocName))
if (auto mallocFunc =
mod.template lookupSymbol<mlir::LLVM::LLVMFuncOp>(mallocName))
return mlir::SymbolRefAttr::get(mallocFunc);
if (auto userMalloc = module.lookupSymbol<mlir::func::FuncOp>(mallocName))
if (auto userMalloc =
mod.template lookupSymbol<mlir::func::FuncOp>(mallocName))
return mlir::SymbolRefAttr::get(userMalloc);
mlir::OpBuilder moduleBuilder(
op->getParentOfType<mlir::ModuleOp>().getBodyRegion());

mlir::OpBuilder moduleBuilder(mod.getBodyRegion());
auto indexType = mlir::IntegerType::get(op.getContext(), 64);
auto mallocDecl = moduleBuilder.create<mlir::LLVM::LLVMFuncOp>(
op.getLoc(), mallocName,
Expand All @@ -940,6 +943,15 @@ getMalloc(fir::AllocMemOp op, mlir::ConversionPatternRewriter &rewriter) {
return mlir::SymbolRefAttr::get(mallocDecl);
}

/// Return the LLVMFuncOp corresponding to the standard malloc call.
static mlir::SymbolRefAttr
getMalloc(fir::AllocMemOp op, mlir::ConversionPatternRewriter &rewriter) {
if (auto mod = op->getParentOfType<mlir::gpu::GPUModuleOp>())
return getMallocInModule(mod, op, rewriter);
auto mod = op->getParentOfType<mlir::ModuleOp>();
return getMallocInModule(mod, op, rewriter);
}

/// Helper function for generating the LLVM IR that computes the distance
/// in bytes between adjacent elements pointed to by a pointer
/// of type \p ptrTy. The result is returned as a value of \p idxTy integer
Expand Down Expand Up @@ -1016,18 +1028,20 @@ struct AllocMemOpConversion : public fir::FIROpConversion<fir::AllocMemOp> {
} // namespace

/// Return the LLVMFuncOp corresponding to the standard free call.
static mlir::SymbolRefAttr getFree(fir::FreeMemOp op,
mlir::ConversionPatternRewriter &rewriter) {
template <typename ModuleOp>
static mlir::SymbolRefAttr
getFreeInModule(ModuleOp mod, fir::FreeMemOp op,
mlir::ConversionPatternRewriter &rewriter) {
static constexpr char freeName[] = "free";
auto module = op->getParentOfType<mlir::ModuleOp>();
// Check if free already defined in the module.
if (auto freeFunc = module.lookupSymbol<mlir::LLVM::LLVMFuncOp>(freeName))
if (auto freeFunc =
mod.template lookupSymbol<mlir::LLVM::LLVMFuncOp>(freeName))
return mlir::SymbolRefAttr::get(freeFunc);
if (auto freeDefinedByUser =
module.lookupSymbol<mlir::func::FuncOp>(freeName))
mod.template lookupSymbol<mlir::func::FuncOp>(freeName))
return mlir::SymbolRefAttr::get(freeDefinedByUser);
// Create llvm declaration for free.
mlir::OpBuilder moduleBuilder(module.getBodyRegion());
mlir::OpBuilder moduleBuilder(mod.getBodyRegion());
auto voidType = mlir::LLVM::LLVMVoidType::get(op.getContext());
auto freeDecl = moduleBuilder.create<mlir::LLVM::LLVMFuncOp>(
rewriter.getUnknownLoc(), freeName,
Expand All @@ -1037,6 +1051,14 @@ static mlir::SymbolRefAttr getFree(fir::FreeMemOp op,
return mlir::SymbolRefAttr::get(freeDecl);
}

static mlir::SymbolRefAttr getFree(fir::FreeMemOp op,
mlir::ConversionPatternRewriter &rewriter) {
if (auto mod = op->getParentOfType<mlir::gpu::GPUModuleOp>())
return getFreeInModule(mod, op, rewriter);
auto mod = op->getParentOfType<mlir::ModuleOp>();
return getFreeInModule(mod, op, rewriter);
}

static unsigned getDimension(mlir::LLVM::LLVMArrayType ty) {
unsigned result = 1;
for (auto eleTy =
Expand Down Expand Up @@ -3730,6 +3752,7 @@ class FIRToLLVMLowering
mlir::configureOpenMPToLLVMConversionLegality(target, typeConverter);
target.addLegalDialect<mlir::omp::OpenMPDialect>();
target.addLegalDialect<mlir::acc::OpenACCDialect>();
target.addLegalDialect<mlir::gpu::GPUDialect>();

// required NOPs for applying a full conversion
target.addLegalOp<mlir::ModuleOp>();
Expand Down
16 changes: 16 additions & 0 deletions flang/test/Fir/convert-to-llvm.fir
Original file line number Diff line number Diff line change
Expand Up @@ -2776,3 +2776,19 @@ func.func @coordinate_array_unknown_size_1d(%arg0: !fir.ptr<!fir.array<? x i32>>
fir.global common @c_(dense<0> : vector<4294967296xi8>) : !fir.array<4294967296xi8>

// CHECK: llvm.mlir.global common @c_(dense<0> : vector<4294967296xi8>) {addr_space = 0 : i32} : !llvm.array<4294967296 x i8>

// -----

gpu.module @cuda_device_mod {
gpu.func @test_alloc_and_freemem_one() {
%z0 = fir.allocmem i32
fir.freemem %z0 : !fir.heap<i32>
gpu.return
}
}

// CHECK: gpu.module @cuda_device_mod {
// CHECK: llvm.func @free(!llvm.ptr)
// CHECK: llvm.func @malloc(i64) -> !llvm.ptr
// CHECK: llvm.call @malloc
// CHECK: lvm.call @free
Loading