Skip to content

Commit b15bd3f

Browse files
authored
[flang][cuda] Add global constructor for allocators registration (#109854)
This pass creates the constructor function to call the allocator registration and adds it to the global_ctors.
1 parent b62075e commit b15bd3f

File tree

5 files changed

+96
-0
lines changed

5 files changed

+96
-0
lines changed

flang/include/flang/Optimizer/Transforms/Passes.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ namespace fir {
3939
#define GEN_PASS_DECL_ASSUMEDRANKOPCONVERSION
4040
#define GEN_PASS_DECL_CHARACTERCONVERSION
4141
#define GEN_PASS_DECL_CFGCONVERSION
42+
#define GEN_PASS_DECL_CUFADDCONSTRUCTOR
4243
#define GEN_PASS_DECL_CUFIMPLICITDEVICEGLOBAL
4344
#define GEN_PASS_DECL_CUFOPCONVERSION
4445
#define GEN_PASS_DECL_EXTERNALNAMECONVERSION

flang/include/flang/Optimizer/Transforms/Passes.td

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -436,4 +436,11 @@ def CufImplicitDeviceGlobal :
436436
];
437437
}
438438

439+
def CUFAddConstructor : Pass<"cuf-add-constructor", "mlir::ModuleOp"> {
440+
let summary = "Add constructor to register CUDA Fortran allocators";
441+
let dependentDialects = [
442+
"mlir::func::FuncDialect"
443+
];
444+
}
445+
439446
#endif // FLANG_OPTIMIZER_TRANSFORMS_PASSES

flang/lib/Optimizer/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ add_flang_library(FIRTransforms
99
CompilerGeneratedNames.cpp
1010
ConstantArgumentGlobalisation.cpp
1111
ControlFlowConverter.cpp
12+
CUFAddConstructor.cpp
1213
CufImplicitDeviceGlobal.cpp
1314
CufOpConversion.cpp
1415
ArrayValueCopy.cpp
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
//===-- CUFAddConstructor.cpp ---------------------------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "flang/Optimizer/Builder/FIRBuilder.h"
10+
#include "flang/Optimizer/Dialect/CUF/CUFOps.h"
11+
#include "flang/Optimizer/Dialect/FIRAttr.h"
12+
#include "flang/Optimizer/Dialect/FIRDialect.h"
13+
#include "flang/Optimizer/Dialect/FIROpsSupport.h"
14+
#include "flang/Runtime/entry-names.h"
15+
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
16+
#include "mlir/Pass/Pass.h"
17+
#include "llvm/ADT/SmallVector.h"
18+
19+
namespace fir {
20+
#define GEN_PASS_DEF_CUFADDCONSTRUCTOR
21+
#include "flang/Optimizer/Transforms/Passes.h.inc"
22+
} // namespace fir
23+
24+
namespace {
25+
26+
static constexpr llvm::StringRef cudaFortranCtorName{
27+
"__cudaFortranConstructor"};
28+
29+
struct CUFAddConstructor
30+
: public fir::impl::CUFAddConstructorBase<CUFAddConstructor> {
31+
32+
void runOnOperation() override {
33+
mlir::ModuleOp mod = getOperation();
34+
mlir::OpBuilder builder{mod.getBodyRegion()};
35+
builder.setInsertionPointToEnd(mod.getBody());
36+
mlir::Location loc = mod.getLoc();
37+
auto *ctx = mod.getContext();
38+
auto voidTy = mlir::LLVM::LLVMVoidType::get(ctx);
39+
auto funcTy =
40+
mlir::LLVM::LLVMFunctionType::get(voidTy, {}, /*isVarArg=*/false);
41+
42+
// Symbol reference to CUFRegisterAllocator.
43+
builder.setInsertionPointToEnd(mod.getBody());
44+
auto registerFuncOp = builder.create<mlir::LLVM::LLVMFuncOp>(
45+
loc, RTNAME_STRING(CUFRegisterAllocator), funcTy);
46+
registerFuncOp.setVisibility(mlir::SymbolTable::Visibility::Private);
47+
auto cufRegisterAllocatorRef = mlir::SymbolRefAttr::get(
48+
mod.getContext(), RTNAME_STRING(CUFRegisterAllocator));
49+
builder.setInsertionPointToEnd(mod.getBody());
50+
51+
// Create the constructor function that cal CUFRegisterAllocator.
52+
builder.setInsertionPointToEnd(mod.getBody());
53+
auto func = builder.create<mlir::LLVM::LLVMFuncOp>(loc, cudaFortranCtorName,
54+
funcTy);
55+
func.setLinkage(mlir::LLVM::Linkage::Internal);
56+
builder.setInsertionPointToStart(func.addEntryBlock(builder));
57+
builder.create<mlir::LLVM::CallOp>(loc, funcTy, cufRegisterAllocatorRef);
58+
builder.create<mlir::LLVM::ReturnOp>(loc, mlir::ValueRange{});
59+
60+
// Create the llvm.global_ctor with the function.
61+
// TODO: We might want to have a utility that retrieve it if already created
62+
// and adds new functions.
63+
builder.setInsertionPointToEnd(mod.getBody());
64+
llvm::SmallVector<mlir::Attribute> funcs;
65+
funcs.push_back(
66+
mlir::FlatSymbolRefAttr::get(mod.getContext(), func.getSymName()));
67+
llvm::SmallVector<int> priorities;
68+
priorities.push_back(0);
69+
builder.create<mlir::LLVM::GlobalCtorsOp>(
70+
mod.getLoc(), builder.getArrayAttr(funcs),
71+
builder.getI32ArrayAttr(priorities));
72+
}
73+
};
74+
75+
} // end anonymous namespace
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
! RUN: bbc -fcuda -emit-hlfir %s -o - | fir-opt --cuf-add-constructor | FileCheck %s
2+
3+
program main
4+
real, device :: ahost(10)
5+
end
6+
7+
! CHECK: llvm.func @_FortranACUFRegisterAllocator() attributes {sym_visibility = "private"}
8+
! CHECK-LABEL: llvm.func internal @__cudaFortranConstructor() {
9+
! CHECK: llvm.call @_FortranACUFRegisterAllocator() : () -> ()
10+
! CHECK: llvm.return
11+
! CHECK: }
12+
! CHECK: llvm.mlir.global_ctors {ctors = [@__cudaFortranConstructor], priorities = [0 : i32]}

0 commit comments

Comments
 (0)