Skip to content

Commit 0eb5c9d

Browse files
authored
[flang][cuda] Copying device globals in the gpu module (llvm#113955)
1 parent e873b41 commit 0eb5c9d

File tree

2 files changed

+40
-0
lines changed

2 files changed

+40
-0
lines changed

flang/lib/Optimizer/Transforms/CUFDeviceGlobal.cpp

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#include "flang/Optimizer/Dialect/FIRDialect.h"
1212
#include "flang/Optimizer/Dialect/FIROps.h"
1313
#include "flang/Optimizer/HLFIR/HLFIROps.h"
14+
#include "flang/Optimizer/Transforms/CUFCommon.h"
1415
#include "flang/Runtime/CUDA/common.h"
1516
#include "flang/Runtime/allocatable.h"
1617
#include "mlir/IR/SymbolTable.h"
@@ -58,6 +59,32 @@ class CUFDeviceGlobal : public fir::impl::CUFDeviceGlobalBase<CUFDeviceGlobal> {
5859
prepareImplicitDeviceGlobals(funcOp, symTable);
5960
return mlir::WalkResult::advance();
6061
});
62+
63+
// Copying the device global variable into the gpu module
64+
mlir::SymbolTable parentSymTable(mod);
65+
auto gpuMod =
66+
parentSymTable.lookup<mlir::gpu::GPUModuleOp>(cudaDeviceModuleName);
67+
if (gpuMod) {
68+
mlir::SymbolTable gpuSymTable(gpuMod);
69+
for (auto globalOp : mod.getOps<fir::GlobalOp>()) {
70+
auto attr = globalOp.getDataAttrAttr();
71+
if (!attr)
72+
continue;
73+
switch (attr.getValue()) {
74+
case cuf::DataAttribute::Device:
75+
case cuf::DataAttribute::Constant:
76+
case cuf::DataAttribute::Managed: {
77+
auto globalName{globalOp.getSymbol().getValue()};
78+
if (gpuSymTable.lookup<fir::GlobalOp>(globalName)) {
79+
break;
80+
}
81+
gpuSymTable.insert(globalOp->clone());
82+
} break;
83+
default:
84+
break;
85+
}
86+
}
87+
}
6188
}
6289
};
6390
} // namespace
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
2+
// RUN: fir-opt --split-input-file --cuf-device-global %s | FileCheck %s
3+
4+
5+
module attributes {fir.defaultkind = "a1c4d8i4l4r4", fir.kindmap = "", gpu.container_module} {
6+
fir.global @_QMmtestsEn(dense<[3, 4, 5, 6, 7]> : tensor<5xi32>) {data_attr = #cuf.cuda<device>} : !fir.array<5xi32>
7+
8+
gpu.module @cuda_device_mod [#nvvm.target] {
9+
}
10+
}
11+
12+
// CHECK: gpu.module @cuda_device_mod [#nvvm.target]
13+
// CHECK-NEXT: fir.global @_QMmtestsEn(dense<[3, 4, 5, 6, 7]> : tensor<5xi32>) {data_attr = #cuf.cuda<device>} : !fir.array<5xi32>

0 commit comments

Comments
 (0)