Skip to content

Commit 956d0dd

Browse files
authored
[flang][cuda] Support builtin global in device global pass (#119626)
1 parent be4a183 commit 956d0dd

File tree

2 files changed

+147
-3
lines changed

2 files changed

+147
-3
lines changed

flang/lib/Optimizer/Transforms/CUFDeviceGlobal.cpp

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#include "flang/Optimizer/Dialect/FIRDialect.h"
1212
#include "flang/Optimizer/Dialect/FIROps.h"
1313
#include "flang/Optimizer/HLFIR/HLFIROps.h"
14+
#include "flang/Optimizer/Support/InternalNames.h"
1415
#include "flang/Optimizer/Transforms/CUFCommon.h"
1516
#include "flang/Runtime/CUDA/common.h"
1617
#include "flang/Runtime/allocatable.h"
@@ -27,6 +28,8 @@ namespace fir {
2728

2829
namespace {
2930

31+
static constexpr llvm::StringRef builtinPrefix = "_QM__fortran_builtins";
32+
3033
static void processAddrOfOp(fir::AddrOfOp addrOfOp,
3134
mlir::SymbolTable &symbolTable,
3235
llvm::DenseSet<fir::GlobalOp> &candidates) {
@@ -35,22 +38,46 @@ static void processAddrOfOp(fir::AddrOfOp addrOfOp,
3538
// TO DO: limit candidates to non-scalars. Scalars appear to have been
3639
// folded in already.
3740
if (globalOp.getConstant()) {
41+
// Limit recursion to builtin global for now.
42+
if (globalOp.getSymName().starts_with(builtinPrefix)) {
43+
globalOp.walk([&](fir::AddrOfOp op) {
44+
processAddrOfOp(op, symbolTable, candidates);
45+
});
46+
}
3847
candidates.insert(globalOp);
3948
}
4049
}
4150
}
4251

52+
static void processEmboxOp(fir::EmboxOp emboxOp, mlir::SymbolTable &symbolTable,
53+
llvm::DenseSet<fir::GlobalOp> &candidates) {
54+
if (auto recTy = mlir::dyn_cast<fir::RecordType>(
55+
fir::unwrapRefType(emboxOp.getMemref().getType())))
56+
// Only look at builtin record type.
57+
if (recTy.getName().starts_with(builtinPrefix))
58+
if (auto globalOp = symbolTable.lookup<fir::GlobalOp>(
59+
fir::NameUniquer::getTypeDescriptorName(recTy.getName()))) {
60+
if (!candidates.contains(globalOp)) {
61+
globalOp.walk([&](fir::AddrOfOp op) {
62+
processAddrOfOp(op, symbolTable, candidates);
63+
});
64+
candidates.insert(globalOp);
65+
}
66+
}
67+
}
68+
4369
static void
4470
prepareImplicitDeviceGlobals(mlir::func::FuncOp funcOp,
4571
mlir::SymbolTable &symbolTable,
4672
llvm::DenseSet<fir::GlobalOp> &candidates) {
47-
4873
auto cudaProcAttr{
4974
funcOp->getAttrOfType<cuf::ProcAttributeAttr>(cuf::getProcAttrName())};
5075
if (cudaProcAttr && cudaProcAttr.getValue() != cuf::ProcAttribute::Host) {
51-
funcOp.walk([&](fir::AddrOfOp addrOfOp) {
52-
processAddrOfOp(addrOfOp, symbolTable, candidates);
76+
funcOp.walk([&](fir::AddrOfOp op) {
77+
processAddrOfOp(op, symbolTable, candidates);
5378
});
79+
funcOp.walk(
80+
[&](fir::EmboxOp op) { processEmboxOp(op, symbolTable, candidates); });
5481
}
5582
}
5683

0 commit comments

Comments
 (0)