Skip to content

Commit 2e89e6b

Browse files
authored
[flang][cuda] Flag globals used in device function (#109460)
1 parent a9352a0 commit 2e89e6b

File tree

5 files changed

+123
-0
lines changed

5 files changed

+123
-0
lines changed

flang/include/flang/Optimizer/Transforms/Passes.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ namespace fir {
3939
#define GEN_PASS_DECL_ASSUMEDRANKOPCONVERSION
4040
#define GEN_PASS_DECL_CHARACTERCONVERSION
4141
#define GEN_PASS_DECL_CFGCONVERSION
42+
#define GEN_PASS_DECL_CUFIMPLICITDEVICEGLOBAL
4243
#define GEN_PASS_DECL_CUFOPCONVERSION
4344
#define GEN_PASS_DECL_EXTERNALNAMECONVERSION
4445
#define GEN_PASS_DECL_MEMREFDATAFLOWOPT

flang/include/flang/Optimizer/Transforms/Passes.td

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -428,4 +428,12 @@ def CufOpConversion : Pass<"cuf-convert", "mlir::ModuleOp"> {
428428
];
429429
}
430430

431+
// Module-level pass exposed under the `cuf-implicit-device-global` command
// line argument. Per its summary, it flags the globals used in device
// functions with a data attribute. The CUF dialect is listed as a dependent
// dialect of the pass.
def CufImplicitDeviceGlobal :
432+
Pass<"cuf-implicit-device-global", "mlir::ModuleOp"> {
433+
let summary = "Flag globals used in device function with data attribute";
434+
let dependentDialects = [
435+
"cuf::CUFDialect"
436+
];
437+
}
438+
431439
#endif // FLANG_OPTIMIZER_TRANSFORMS_PASSES

flang/lib/Optimizer/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ add_flang_library(FIRTransforms
99
CompilerGeneratedNames.cpp
1010
ConstantArgumentGlobalisation.cpp
1111
ControlFlowConverter.cpp
12+
CufImplicitDeviceGlobal.cpp
1213
CufOpConversion.cpp
1314
ArrayValueCopy.cpp
1415
ExternalNameConversion.cpp
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
//===-- CufImplicitDeviceGlobal.cpp ---------------------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "flang/Common/Fortran.h"
10+
#include "flang/Optimizer/Dialect/CUF/CUFOps.h"
11+
#include "flang/Optimizer/Dialect/FIRDialect.h"
12+
#include "flang/Optimizer/Dialect/FIROps.h"
13+
#include "flang/Optimizer/HLFIR/HLFIROps.h"
14+
#include "flang/Runtime/CUDA/common.h"
15+
#include "flang/Runtime/allocatable.h"
16+
#include "mlir/IR/SymbolTable.h"
17+
#include "mlir/Pass/Pass.h"
18+
#include "mlir/Transforms/DialectConversion.h"
19+
20+
namespace fir {
21+
#define GEN_PASS_DEF_CUFIMPLICITDEVICEGLOBAL
22+
#include "flang/Optimizer/Transforms/Passes.h.inc"
23+
} // namespace fir
24+
25+
namespace {
26+
27+
/// Tag the globals referenced by a CUDA device procedure with a CUDA data
/// attribute.
///
/// Nothing is done when \p funcOp has no `cuf.proc_attr` attribute or when that
/// attribute is `host`. Otherwise every `fir.address_of` found anywhere inside
/// the function is resolved through \p symbolTable; a resolved `fir.global`
/// that has no data attribute yet receives `constant` when it is a constant
/// global and `device` otherwise.
///
/// \param funcOp function whose body is scanned.
/// \param symbolTable symbol table used to resolve the referenced symbols.
/// \param onlyConstant when true (the default), only constant globals are
///        flagged; non-constant globals are left untouched.
static void prepareImplicitDeviceGlobals(mlir::func::FuncOp funcOp,
                                         mlir::SymbolTable &symbolTable,
                                         bool onlyConstant = true) {
  auto cudaProcAttr{
      funcOp->getAttrOfType<cuf::ProcAttributeAttr>(cuf::getProcAttrName())};
  if (!cudaProcAttr || cudaProcAttr.getValue() == cuf::ProcAttribute::Host)
    return;

  // Walk the function instead of iterating only the operations held directly
  // by its body region, so that fir.address_of operations nested inside the
  // regions of other operations are visited as well.
  funcOp.walk([&](fir::AddrOfOp addrOfOp) {
    auto globalOp = symbolTable.lookup<fir::GlobalOp>(
        addrOfOp.getSymbol().getRootReference().getValue());
    // The symbol may not resolve to a fir.global in this symbol table.
    if (!globalOp)
      return;
    // Never override a data attribute that is already present.
    if (globalOp.getDataAttr())
      return;
    const bool isConstant = static_cast<bool>(globalOp.getConstant());
    if (onlyConstant && !isConstant)
      return;
    globalOp.setDataAttrAttr(cuf::DataAttributeAttr::get(
        funcOp.getContext(), isConstant ? cuf::DataAttribute::Constant
                                        : cuf::DataAttribute::Device));
  });
}
47+
48+
class CufImplicitDeviceGlobal
49+
: public fir::impl::CufImplicitDeviceGlobalBase<CufImplicitDeviceGlobal> {
50+
public:
51+
void runOnOperation() override {
52+
mlir::Operation *op = getOperation();
53+
mlir::ModuleOp mod = mlir::dyn_cast<mlir::ModuleOp>(op);
54+
if (!mod)
55+
return signalPassFailure();
56+
57+
mlir::SymbolTable symTable(mod);
58+
mod.walk([&](mlir::func::FuncOp funcOp) {
59+
prepareImplicitDeviceGlobals(funcOp, symTable);
60+
return mlir::WalkResult::advance();
61+
});
62+
}
63+
};
64+
} // namespace
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
// RUN: fir-opt --split-input-file --cuf-implicit-device-global %s | FileCheck %s
2+
3+
// Test that globals used in device functions are flagged with the correct
4+
// attribute.
5+
6+
// Positive case: the function carries cuf.proc_attr = #cuf.cuda_proc<global>
// and takes the address of a constant fir.global through fir.address_of. The
// CHECK lines that follow expect that global to gain
// data_attr = #cuf.cuda<constant>.
func.func @_QMdataPsetvalue() attributes {cuf.proc_attr = #cuf.cuda_proc<global>} {
7+
%c6_i32 = arith.constant 6 : i32
8+
%21 = fir.address_of(@_QQclX6995815537abaf90e86ce166af128f3a) : !fir.ref<!fir.char<1,32>>
9+
%22 = fir.convert %21 : (!fir.ref<!fir.char<1,32>>) -> !fir.ref<i8>
10+
%c14_i32 = arith.constant 14 : i32
11+
%23 = fir.call @_FortranAioBeginExternalListOutput(%c6_i32, %22, %c14_i32) fastmath<contract> : (i32, !fir.ref<i8>, i32) -> !fir.ref<i8>
12+
return
13+
}
14+
15+
func.func private @_FortranAioBeginExternalListOutput(i32, !fir.ref<i8>, i32) -> !fir.ref<i8> attributes {fir.io, fir.runtime}
16+
fir.global linkonce @_QQclX6995815537abaf90e86ce166af128f3a constant : !fir.char<1,32> {
17+
%0 = fir.string_lit "cuda-implicit-device-global.fir\00"(32) : !fir.char<1,32>
18+
fir.has_value %0 : !fir.char<1,32>
19+
}
20+
21+
// CHECK-LABEL: func.func @_QMdataPsetvalue() attributes {cuf.proc_attr = #cuf.cuda_proc<global>}
22+
23+
// CHECK: %[[GLOBAL:.*]] = fir.address_of(@_QQcl[[SYMBOL:.*]]) : !fir.ref<!fir.char<1,32>>
24+
// CHECK: %[[CONV:.*]] = fir.convert %[[GLOBAL]] : (!fir.ref<!fir.char<1,32>>) -> !fir.ref<i8>
25+
// CHECK: fir.call @_FortranAioBeginExternalListOutput(%{{.*}}, %[[CONV]], %{{.*}}) fastmath<contract> : (i32, !fir.ref<i8>, i32) -> !fir.ref<i8>
26+
// CHECK: fir.global linkonce @_QQcl[[SYMBOL]] {data_attr = #cuf.cuda<constant>} constant : !fir.char<1,32>
27+
28+
// -----
29+
30+
// Negative case: same body as the previous test, but the function has no
// cuf.proc_attr attribute. The CHECK lines that follow expect the referenced
// fir.global to be printed without any data_attr.
func.func @_QMdataPsetvalue() {
31+
%c6_i32 = arith.constant 6 : i32
32+
%21 = fir.address_of(@_QQclX6995815537abaf90e86ce166af128f3a) : !fir.ref<!fir.char<1,32>>
33+
%22 = fir.convert %21 : (!fir.ref<!fir.char<1,32>>) -> !fir.ref<i8>
34+
%c14_i32 = arith.constant 14 : i32
35+
%23 = fir.call @_FortranAioBeginExternalListOutput(%c6_i32, %22, %c14_i32) fastmath<contract> : (i32, !fir.ref<i8>, i32) -> !fir.ref<i8>
36+
return
37+
}
38+
39+
func.func private @_FortranAioBeginExternalListOutput(i32, !fir.ref<i8>, i32) -> !fir.ref<i8> attributes {fir.io, fir.runtime}
40+
fir.global linkonce @_QQclX6995815537abaf90e86ce166af128f3a constant : !fir.char<1,32> {
41+
%0 = fir.string_lit "cuda-implicit-device-global.fir\00"(32) : !fir.char<1,32>
42+
fir.has_value %0 : !fir.char<1,32>
43+
}
44+
45+
// CHECK-LABEL: func.func @_QMdataPsetvalue()
46+
// CHECK: %[[GLOBAL:.*]] = fir.address_of(@_QQcl[[SYMBOL:.*]]) : !fir.ref<!fir.char<1,32>>
47+
// CHECK: %[[CONV:.*]] = fir.convert %[[GLOBAL]] : (!fir.ref<!fir.char<1,32>>) -> !fir.ref<i8>
48+
// CHECK: fir.call @_FortranAioBeginExternalListOutput(%{{.*}}, %[[CONV]], %{{.*}}) fastmath<contract> : (i32, !fir.ref<i8>, i32) -> !fir.ref<i8>
49+
// CHECK: fir.global linkonce @_QQcl[[SYMBOL]] constant : !fir.char<1,32>

0 commit comments

Comments
 (0)