11
11
#include " flang/Optimizer/Dialect/FIRDialect.h"
12
12
#include " flang/Optimizer/Dialect/FIROps.h"
13
13
#include " flang/Optimizer/HLFIR/HLFIROps.h"
14
+ #include " flang/Optimizer/Support/InternalNames.h"
14
15
#include " flang/Optimizer/Transforms/CUFCommon.h"
15
16
#include " flang/Runtime/CUDA/common.h"
16
17
#include " flang/Runtime/allocatable.h"
@@ -27,6 +28,8 @@ namespace fir {
27
28
28
29
namespace {
29
30
31
+ static constexpr llvm::StringRef builtinPrefix = " _QM__fortran_builtins" ;
32
+
30
33
static void processAddrOfOp (fir::AddrOfOp addrOfOp,
31
34
mlir::SymbolTable &symbolTable,
32
35
llvm::DenseSet<fir::GlobalOp> &candidates) {
@@ -35,22 +38,46 @@ static void processAddrOfOp(fir::AddrOfOp addrOfOp,
35
38
// TO DO: limit candidates to non-scalars. Scalars appear to have been
36
39
// folded in already.
37
40
if (globalOp.getConstant ()) {
41
+ // Limit recursion to builtin global for now.
42
+ if (globalOp.getSymName ().starts_with (builtinPrefix)) {
43
+ globalOp.walk ([&](fir::AddrOfOp op) {
44
+ processAddrOfOp (op, symbolTable, candidates);
45
+ });
46
+ }
38
47
candidates.insert (globalOp);
39
48
}
40
49
}
41
50
}
42
51
52
+ static void processEmboxOp (fir::EmboxOp emboxOp, mlir::SymbolTable &symbolTable,
53
+ llvm::DenseSet<fir::GlobalOp> &candidates) {
54
+ if (auto recTy = mlir::dyn_cast<fir::RecordType>(
55
+ fir::unwrapRefType (emboxOp.getMemref ().getType ())))
56
+ // Only look at builtin record type.
57
+ if (recTy.getName ().starts_with (builtinPrefix))
58
+ if (auto globalOp = symbolTable.lookup <fir::GlobalOp>(
59
+ fir::NameUniquer::getTypeDescriptorName (recTy.getName ()))) {
60
+ if (!candidates.contains (globalOp)) {
61
+ globalOp.walk ([&](fir::AddrOfOp op) {
62
+ processAddrOfOp (op, symbolTable, candidates);
63
+ });
64
+ candidates.insert (globalOp);
65
+ }
66
+ }
67
+ }
68
+
43
69
static void
44
70
prepareImplicitDeviceGlobals (mlir::func::FuncOp funcOp,
45
71
mlir::SymbolTable &symbolTable,
46
72
llvm::DenseSet<fir::GlobalOp> &candidates) {
47
-
48
73
auto cudaProcAttr{
49
74
funcOp->getAttrOfType <cuf::ProcAttributeAttr>(cuf::getProcAttrName ())};
50
75
if (cudaProcAttr && cudaProcAttr.getValue () != cuf::ProcAttribute::Host) {
51
- funcOp.walk ([&](fir::AddrOfOp addrOfOp ) {
52
- processAddrOfOp (addrOfOp , symbolTable, candidates);
76
+ funcOp.walk ([&](fir::AddrOfOp op ) {
77
+ processAddrOfOp (op , symbolTable, candidates);
53
78
});
79
+ funcOp.walk (
80
+ [&](fir::EmboxOp op) { processEmboxOp (op, symbolTable, candidates); });
54
81
}
55
82
}
56
83
0 commit comments