Skip to content

Commit 9165848

Browse files
authored
[flang][cuda] Sync global descriptor when nullifying pointer (#121595)
1 parent 78f0447 commit 9165848

File tree

5 files changed

+32
-18
lines changed

5 files changed

+32
-18
lines changed

flang/include/flang/Optimizer/Builder/CUFCommon.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,10 @@
1515

1616
static constexpr llvm::StringRef cudaDeviceModuleName = "cuda_device_mod";
1717

18+
namespace fir {
19+
class FirOpBuilder;
20+
} // namespace fir
21+
1822
namespace cuf {
1923

2024
/// Retrieve or create the CUDA Fortran GPU module in the given \p mod.
@@ -24,6 +28,8 @@ mlir::gpu::GPUModuleOp getOrCreateGPUModule(mlir::ModuleOp mod,
2428
bool isInCUDADeviceContext(mlir::Operation *op);
2529
bool isRegisteredDeviceGlobal(fir::GlobalOp op);
2630

31+
/// Emit a cuf.sync_descriptor operation for \p box using \p builder when
/// \p box is produced by an hlfir.declare whose memref is a fir.address_of a
/// global for which isRegisteredDeviceGlobal returns true. Does nothing
/// otherwise.
void genPointerSync(const mlir::Value box, fir::FirOpBuilder &builder);
32+
2733
} // namespace cuf
2834

2935
#endif // FORTRAN_OPTIMIZER_TRANSFORMS_CUFCOMMON_H_

flang/lib/Lower/Allocatable.cpp

Lines changed: 2 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1088,35 +1088,20 @@ bool Fortran::lower::isArraySectionWithoutVectorSubscript(
10881088
!Fortran::evaluate::HasVectorSubscript(expr);
10891089
}
10901090

1091-
/// Create a cuf::SyncDescriptorOp for the global behind \p box.
///
/// The operation is only created when \p box is defined by an
/// hlfir::DeclareOp whose memref comes from a fir::AddrOfOp, the referenced
/// symbol resolves to a fir::GlobalOp in the enclosing module, and
/// cuf::isRegisteredDeviceGlobal accepts that global. In every other case the
/// function is a no-op.
static void genCUFPointerSync(const mlir::Value box,
                              fir::FirOpBuilder &builder) {
  auto declare = box.getDefiningOp<hlfir::DeclareOp>();
  if (!declare)
    return;

  auto addressOf = declare.getMemref().getDefiningOp<fir::AddrOfOp>();
  if (!addressOf)
    return;

  // Resolve the referenced symbol in the module that contains the address_of.
  auto parentModule = addressOf->getParentOfType<mlir::ModuleOp>();
  auto global = parentModule.lookupSymbol<fir::GlobalOp>(addressOf.getSymbol());
  if (!global)
    return;

  if (!cuf::isRegisteredDeviceGlobal(global))
    return;

  builder.create<cuf::SyncDescriptorOp>(box.getLoc(), addressOf.getSymbol());
}
1106-
11071091
void Fortran::lower::associateMutableBox(
11081092
Fortran::lower::AbstractConverter &converter, mlir::Location loc,
11091093
const fir::MutableBoxValue &box, const Fortran::lower::SomeExpr &source,
11101094
mlir::ValueRange lbounds, Fortran::lower::StatementContext &stmtCtx) {
11111095
fir::FirOpBuilder &builder = converter.getFirOpBuilder();
11121096
if (Fortran::evaluate::UnwrapExpr<Fortran::evaluate::NullPointer>(source)) {
11131097
fir::factory::disassociateMutableBox(builder, loc, box);
1098+
cuf::genPointerSync(box.getAddr(), builder);
11141099
return;
11151100
}
11161101
if (converter.getLoweringOptions().getLowerToHighLevelFIR()) {
11171102
fir::ExtendedValue rhs = converter.genExprAddr(loc, source, stmtCtx);
11181103
fir::factory::associateMutableBox(builder, loc, box, rhs, lbounds);
1119-
genCUFPointerSync(box.getAddr(), builder);
1104+
cuf::genPointerSync(box.getAddr(), builder);
11201105
return;
11211106
}
11221107
// The right hand side is not evaluated into a temp. Array sections can

flang/lib/Lower/Bridge.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
#include "flang/Lower/StatementContext.h"
3535
#include "flang/Lower/Support/Utils.h"
3636
#include "flang/Optimizer/Builder/BoxValue.h"
37+
#include "flang/Optimizer/Builder/CUFCommon.h"
3738
#include "flang/Optimizer/Builder/Character.h"
3839
#include "flang/Optimizer/Builder/FIRBuilder.h"
3940
#include "flang/Optimizer/Builder/Runtime/Assign.h"
@@ -3952,6 +3953,7 @@ class FirConverter : public Fortran::lower::AbstractConverter {
39523953
} else {
39533954
fir::MutableBoxValue box = genExprMutableBox(loc, *expr);
39543955
fir::factory::disassociateMutableBox(*builder, loc, box);
3956+
cuf::genPointerSync(box.getAddr(), *builder);
39553957
}
39563958
}
39573959
}

flang/lib/Optimizer/Builder/CUFCommon.cpp

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,9 @@
77
//===----------------------------------------------------------------------===//
88

99
#include "flang/Optimizer/Builder/CUFCommon.h"
10+
#include "flang/Optimizer/Builder/FIRBuilder.h"
1011
#include "flang/Optimizer/Dialect/CUF/CUFOps.h"
12+
#include "flang/Optimizer/HLFIR/HLFIROps.h"
1113
#include "mlir/Dialect/Func/IR/FuncOps.h"
1214
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
1315

@@ -54,3 +56,18 @@ bool cuf::isRegisteredDeviceGlobal(fir::GlobalOp op) {
5456
return true;
5557
return false;
5658
}
59+
60+
/// Generate a cuf::SyncDescriptorOp for the global that backs \p box.
///
/// \p box must trace back through an hlfir::DeclareOp to a fir::AddrOfOp
/// naming a fir::GlobalOp for which cuf::isRegisteredDeviceGlobal returns
/// true; if any link of that chain is missing, nothing is emitted.
void cuf::genPointerSync(const mlir::Value box, fir::FirOpBuilder &builder) {
  // Step 1: the box has to be the result of an hlfir.declare.
  auto declOp = box.getDefiningOp<hlfir::DeclareOp>();
  if (!declOp)
    return;

  // Step 2: the declared memref has to be the address of a global symbol.
  auto addrOp = declOp.getMemref().getDefiningOp<fir::AddrOfOp>();
  if (!addrOp)
    return;

  // Step 3: look the symbol up in the enclosing module and make sure it is a
  // registered device global.
  auto enclosingModule = addrOp->getParentOfType<mlir::ModuleOp>();
  auto deviceGlobal =
      enclosingModule.lookupSymbol<fir::GlobalOp>(addrOp.getSymbol());
  if (!deviceGlobal || !cuf::isRegisteredDeviceGlobal(deviceGlobal))
    return;

  builder.create<cuf::SyncDescriptorOp>(box.getLoc(), addrOp.getSymbol());
}

flang/test/Lower/CUDA/cuda-pointer-sync.cuf

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,14 @@ use devptr
88
real, device, target, dimension(4) :: a_dev
99
a_dev = 42.0
1010
dev_ptr => a_dev
11+
12+
dev_ptr => null()
13+
14+
nullify(dev_ptr)
1115
end
1216

1317
! CHECK: fir.global @_QMdevptrEdev_ptr {data_attr = #cuf.cuda<device>} : !fir.box<!fir.ptr<!fir.array<?xf32>>>
1418
! CHECK-LABEL: func.func @_QQmain()
1519
! CHECK: fir.embox
1620
! CHECK: fir.store
17-
! CHECK: cuf.sync_descriptor @_QMdevptrEdev_ptr
21+
! CHECK-COUNT-3: cuf.sync_descriptor @_QMdevptrEdev_ptr

0 commit comments

Comments
 (0)