Skip to content

Commit f4cecfe

Browse files
authored
[flang][cuda] Bring PARAMETER arrays into the GPU module (#146416)
1 parent 56739f5 commit f4cecfe

File tree

2 files changed

+22
-1
lines changed

2 files changed

+22
-1
lines changed

flang/lib/Optimizer/Transforms/CUFDeviceGlobal.cpp

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,8 +113,16 @@ class CUFDeviceGlobal : public fir::impl::CUFDeviceGlobalBase<CUFDeviceGlobal> {
113113
return signalPassFailure();
114114
mlir::SymbolTable gpuSymTable(gpuMod);
115115
for (auto globalOp : mod.getOps<fir::GlobalOp>()) {
116-
if (cuf::isRegisteredDeviceGlobal(globalOp))
116+
if (cuf::isRegisteredDeviceGlobal(globalOp)) {
117117
candidates.insert(globalOp);
118+
} else if (globalOp.getConstant() &&
119+
mlir::isa<fir::SequenceType>(
120+
fir::unwrapRefType(globalOp.resultType()))) {
121+
mlir::Attribute initAttr =
122+
globalOp.getInitVal().value_or(mlir::Attribute());
123+
if (initAttr && mlir::dyn_cast<mlir::DenseElementsAttr>(initAttr))
124+
candidates.insert(globalOp);
125+
}
118126
}
119127
for (auto globalOp : candidates) {
120128
auto globalName{globalOp.getSymbol().getValue()};

flang/test/Fir/CUDA/cuda-device-global.f90

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,16 @@ module attributes {fir.defaultkind = "a1c4d8i4l4r4", fir.kindmap = "", gpu.conta
1111

1212
// CHECK: gpu.module @cuda_device_mo
1313
// CHECK-NEXT: fir.global @_QMmtestsEn(dense<[3, 4, 5, 6, 7]> : tensor<5xi32>) {data_attr = #cuf.cuda<device>} : !fir.array<5xi32>
14+
15+
// -----
16+
17+
module attributes {fir.defaultkind = "a1c4d8i4l4r4", fir.kindmap = "", gpu.container_module} {
18+
fir.global @_QMm1ECb(dense<[90, 100, 110]> : tensor<3xi32>) constant : !fir.array<3xi32>
19+
fir.global @_QMm2ECc(dense<[100, 200, 300]> : tensor<3xi32>) constant : !fir.array<3xi32>
20+
}
21+
22+
// CHECK: fir.global @_QMm1ECb
23+
// CHECK: fir.global @_QMm2ECc
24+
// CHECK: gpu.module @cuda_device_mod
25+
// CHECK-DAG: fir.global @_QMm2ECc
26+
// CHECK-DAG: fir.global @_QMm1ECb

0 commit comments

Comments
 (0)