Skip to content

Commit 478e516

Browse files
authored
[flang][cuda] Sync double descriptor after c_f_pointer call (llvm#130194)
After a global device pointer is set through `c_f_pointer`, we need to sync the double descriptor so the version on the device is also up to date.
1 parent 55f86cf commit 478e516

File tree

12 files changed

+139
-49
lines changed

12 files changed

+139
-49
lines changed

flang/include/flang/Lower/Cuda.h

Lines changed: 0 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -20,27 +20,6 @@
2020
#include "mlir/Dialect/OpenACC/OpenACC.h"
2121

2222
namespace Fortran::lower {
23-
// Check if the insertion point is currently in a device context. HostDevice
24-
// subprogram are not considered fully device context so it will return false
25-
// for it.
26-
// If the insertion point is inside an OpenACC region op, it is considered
27-
// device context.
28-
static bool inline isCudaDeviceContext(fir::FirOpBuilder &builder) {
29-
if (builder.getRegion().getParentOfType<cuf::KernelOp>())
30-
return true;
31-
if (builder.getRegion()
32-
.getParentOfType<mlir::acc::ComputeRegionOpInterface>())
33-
return true;
34-
if (auto funcOp = builder.getRegion().getParentOfType<mlir::func::FuncOp>()) {
35-
if (auto cudaProcAttr =
36-
funcOp.getOperation()->getAttrOfType<cuf::ProcAttributeAttr>(
37-
cuf::getProcAttrName())) {
38-
return cudaProcAttr.getValue() != cuf::ProcAttribute::Host &&
39-
cudaProcAttr.getValue() != cuf::ProcAttribute::HostDevice;
40-
}
41-
}
42-
return false;
43-
}
4423

4524
static inline unsigned getAllocatorIdx(const Fortran::semantics::Symbol &sym) {
4625
std::optional<Fortran::common::CUDADataAttr> cudaAttr =

flang/include/flang/Optimizer/Builder/CUFCommon.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,10 @@ namespace cuf {
2525
mlir::gpu::GPUModuleOp getOrCreateGPUModule(mlir::ModuleOp mod,
2626
mlir::SymbolTable &symTab);
2727

28-
bool isInCUDADeviceContext(mlir::Operation *op);
28+
bool isCUDADeviceContext(mlir::Operation *op);
29+
bool isCUDADeviceContext(mlir::Region &);
2930
bool isRegisteredDeviceGlobal(fir::GlobalOp op);
31+
bool isRegisteredDeviceAttr(std::optional<cuf::DataAttribute> attr);
3032

3133
void genPointerSync(const mlir::Value box, fir::FirOpBuilder &builder);
3234

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
//===-- Descriptor.h - CUDA descriptor runtime API calls --------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef FORTRAN_OPTIMIZER_BUILDER_RUNTIME_CUDA_DESCRIPTOR_H_
10+
#define FORTRAN_OPTIMIZER_BUILDER_RUNTIME_CUDA_DESCRIPTOR_H_
11+
12+
#include "mlir/IR/Value.h"
13+
14+
namespace mlir {
15+
class Location;
16+
} // namespace mlir
17+
18+
namespace fir {
19+
class FirOpBuilder;
20+
}
21+
22+
namespace fir::runtime::cuda {
23+
24+
/// Generate runtime call to sync the double descriptor referenced by
25+
/// \p hostPtr.
26+
void genSyncGlobalDescriptor(fir::FirOpBuilder &builder, mlir::Location loc,
27+
mlir::Value hostPtr);
28+
29+
} // namespace fir::runtime::cuda
30+
31+
#endif // FORTRAN_OPTIMIZER_BUILDER_RUNTIME_CUDA_DESCRIPTOR_H_

flang/lib/Lower/Allocatable.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -470,7 +470,7 @@ class AllocateStmtHelper {
470470
void genSimpleAllocation(const Allocation &alloc,
471471
const fir::MutableBoxValue &box) {
472472
bool isCudaSymbol = Fortran::semantics::HasCUDAAttr(alloc.getSymbol());
473-
bool isCudaDeviceContext = Fortran::lower::isCudaDeviceContext(builder);
473+
bool isCudaDeviceContext = cuf::isCUDADeviceContext(builder.getRegion());
474474
bool inlineAllocation = !box.isDerived() && !errorManager.hasStatSpec() &&
475475
!alloc.type.IsPolymorphic() &&
476476
!alloc.hasCoarraySpec() && !useAllocateRuntime &&
@@ -862,7 +862,7 @@ genDeallocate(fir::FirOpBuilder &builder,
862862
mlir::Value declaredTypeDesc = {},
863863
const Fortran::semantics::Symbol *symbol = nullptr) {
864864
bool isCudaSymbol = symbol && Fortran::semantics::HasCUDAAttr(*symbol);
865-
bool isCudaDeviceContext = Fortran::lower::isCudaDeviceContext(builder);
865+
bool isCudaDeviceContext = cuf::isCUDADeviceContext(builder.getRegion());
866866
bool inlineDeallocation =
867867
!box.isDerived() && !box.isPolymorphic() && !box.hasAssumedRank() &&
868868
!box.isUnlimitedPolymorphic() && !errorManager.hasStatSpec() &&

flang/lib/Lower/Bridge.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4689,7 +4689,7 @@ class FirConverter : public Fortran::lower::AbstractConverter {
46894689
mlir::Location loc = getCurrentLocation();
46904690
fir::FirOpBuilder &builder = getFirOpBuilder();
46914691

4692-
bool isInDeviceContext = Fortran::lower::isCudaDeviceContext(builder);
4692+
bool isInDeviceContext = cuf::isCUDADeviceContext(builder.getRegion());
46934693

46944694
bool isCUDATransfer =
46954695
IsCUDADataTransfer(assign.lhs, assign.rhs) && !isInDeviceContext;

flang/lib/Optimizer/Builder/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ add_flang_library(FIRBuilder
1818
Runtime/Assign.cpp
1919
Runtime/Character.cpp
2020
Runtime/Command.cpp
21+
Runtime/CUDA/Descriptor.cpp
2122
Runtime/Derived.cpp
2223
Runtime/EnvironmentDefaults.cpp
2324
Runtime/Exceptions.cpp

flang/lib/Optimizer/Builder/CUFCommon.cpp

Lines changed: 28 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include "flang/Optimizer/HLFIR/HLFIROps.h"
1313
#include "mlir/Dialect/Func/IR/FuncOps.h"
1414
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
15+
#include "mlir/Dialect/OpenACC/OpenACC.h"
1516

1617
/// Retrieve or create the CUDA Fortran GPU module in the given \p mod.
1718
mlir::gpu::GPUModuleOp cuf::getOrCreateGPUModule(mlir::ModuleOp mod,
@@ -31,32 +32,47 @@ mlir::gpu::GPUModuleOp cuf::getOrCreateGPUModule(mlir::ModuleOp mod,
3132
return gpuMod;
3233
}
3334

34-
bool cuf::isInCUDADeviceContext(mlir::Operation *op) {
35-
if (!op)
35+
bool cuf::isCUDADeviceContext(mlir::Operation *op) {
36+
if (!op || !op->getParentRegion())
3637
return false;
37-
if (op->getParentOfType<cuf::KernelOp>() ||
38-
op->getParentOfType<mlir::gpu::GPUFuncOp>())
38+
return isCUDADeviceContext(*op->getParentRegion());
39+
}
40+
41+
// Check if the insertion point is currently in a device context. HostDevice
42+
// subprogram are not considered fully device context so it will return false
43+
// for it.
44+
// If the insertion point is inside an OpenACC region op, it is considered
45+
// device context.
46+
bool cuf::isCUDADeviceContext(mlir::Region &region) {
47+
if (region.getParentOfType<cuf::KernelOp>())
48+
return true;
49+
if (region.getParentOfType<mlir::acc::ComputeRegionOpInterface>())
3950
return true;
40-
if (auto funcOp = op->getParentOfType<mlir::func::FuncOp>()) {
41-
if (auto cudaProcAttr = funcOp->getAttrOfType<cuf::ProcAttributeAttr>(
42-
cuf::getProcAttrName())) {
43-
return cudaProcAttr.getValue() != cuf::ProcAttribute::Host;
51+
if (auto funcOp = region.getParentOfType<mlir::func::FuncOp>()) {
52+
if (auto cudaProcAttr =
53+
funcOp.getOperation()->getAttrOfType<cuf::ProcAttributeAttr>(
54+
cuf::getProcAttrName())) {
55+
return cudaProcAttr.getValue() != cuf::ProcAttribute::Host &&
56+
cudaProcAttr.getValue() != cuf::ProcAttribute::HostDevice;
4457
}
4558
}
4659
return false;
4760
}
4861

49-
bool cuf::isRegisteredDeviceGlobal(fir::GlobalOp op) {
50-
if (op.getConstant())
51-
return false;
52-
auto attr = op.getDataAttr();
62+
bool cuf::isRegisteredDeviceAttr(std::optional<cuf::DataAttribute> attr) {
5363
if (attr && (*attr == cuf::DataAttribute::Device ||
5464
*attr == cuf::DataAttribute::Managed ||
5565
*attr == cuf::DataAttribute::Constant))
5666
return true;
5767
return false;
5868
}
5969

70+
bool cuf::isRegisteredDeviceGlobal(fir::GlobalOp op) {
71+
if (op.getConstant())
72+
return false;
73+
return isRegisteredDeviceAttr(op.getDataAttr());
74+
}
75+
6076
void cuf::genPointerSync(const mlir::Value box, fir::FirOpBuilder &builder) {
6177
if (auto declareOp = box.getDefiningOp<hlfir::DeclareOp>()) {
6278
if (auto addrOfOp = declareOp.getMemref().getDefiningOp<fir::AddrOfOp>()) {

flang/lib/Optimizer/Builder/IntrinsicCall.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,14 @@
1616
#include "flang/Optimizer/Builder/IntrinsicCall.h"
1717
#include "flang/Common/static-multimap-view.h"
1818
#include "flang/Optimizer/Builder/BoxValue.h"
19+
#include "flang/Optimizer/Builder/CUFCommon.h"
1920
#include "flang/Optimizer/Builder/Character.h"
2021
#include "flang/Optimizer/Builder/Complex.h"
2122
#include "flang/Optimizer/Builder/FIRBuilder.h"
2223
#include "flang/Optimizer/Builder/MutableBox.h"
2324
#include "flang/Optimizer/Builder/PPCIntrinsicCall.h"
2425
#include "flang/Optimizer/Builder/Runtime/Allocatable.h"
26+
#include "flang/Optimizer/Builder/Runtime/CUDA/Descriptor.h"
2527
#include "flang/Optimizer/Builder/Runtime/Character.h"
2628
#include "flang/Optimizer/Builder/Runtime/Command.h"
2729
#include "flang/Optimizer/Builder/Runtime/Derived.h"
@@ -38,6 +40,7 @@
3840
#include "flang/Optimizer/Dialect/FIROps.h"
3941
#include "flang/Optimizer/Dialect/FIROpsSupport.h"
4042
#include "flang/Optimizer/Dialect/Support/FIRContext.h"
43+
#include "flang/Optimizer/HLFIR/HLFIROps.h"
4144
#include "flang/Optimizer/Support/FatalError.h"
4245
#include "flang/Optimizer/Support/Utils.h"
4346
#include "flang/Runtime/entry-names.h"
@@ -3254,6 +3257,17 @@ void IntrinsicLibrary::genCFPointer(llvm::ArrayRef<fir::ExtendedValue> args) {
32543257

32553258
fir::factory::associateMutableBox(builder, loc, *fPtr, getCPtrExtVal(*fPtr),
32563259
/*lbounds=*/mlir::ValueRange{});
3260+
3261+
// If the pointer is a registered CUDA Fortran variable, the descriptor needs
3262+
// to be synced.
3263+
if (auto declare = mlir::dyn_cast_or_null<hlfir::DeclareOp>(
3264+
fPtr->getAddr().getDefiningOp()))
3265+
if (declare.getMemref().getDefiningOp() &&
3266+
mlir::isa<fir::AddrOfOp>(declare.getMemref().getDefiningOp()))
3267+
if (cuf::isRegisteredDeviceAttr(declare.getDataAttr()) &&
3268+
!cuf::isCUDADeviceContext(builder.getRegion()))
3269+
fir::runtime::cuda::genSyncGlobalDescriptor(builder, loc,
3270+
declare.getMemref());
32573271
}
32583272

32593273
// C_F_PROCPOINTER
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
2+
//===-- Descriptor.cpp -- CUDA descriptor runtime API calls ---------------===//
3+
//
4+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
5+
// See https://llvm.org/LICENSE.txt for license information.
6+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7+
//
8+
//===----------------------------------------------------------------------===//
9+
//
10+
// Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
11+
//
12+
//===----------------------------------------------------------------------===//
13+
14+
#include "flang/Optimizer/Builder/Runtime/CUDA/Descriptor.h"
15+
#include "flang/Optimizer/Builder/FIRBuilder.h"
16+
#include "flang/Optimizer/Builder/Runtime/RTBuilder.h"
17+
#include "flang/Runtime/CUDA/descriptor.h"
18+
19+
using namespace Fortran::runtime::cuda;
20+
21+
void fir::runtime::cuda::genSyncGlobalDescriptor(fir::FirOpBuilder &builder,
22+
mlir::Location loc,
23+
mlir::Value hostPtr) {
24+
mlir::func::FuncOp callee =
25+
fir::runtime::getRuntimeFunc<mkRTKey(CUFSyncGlobalDescriptor)>(loc,
26+
builder);
27+
auto fTy = callee.getFunctionType();
28+
mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
29+
mlir::Value sourceLine =
30+
fir::factory::locationToLineNo(builder, loc, fTy.getInput(2));
31+
llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
32+
builder, loc, fTy, hostPtr, sourceFile, sourceLine)};
33+
builder.create<fir::CallOp>(loc, callee, args);
34+
}

flang/lib/Optimizer/Transforms/CUFOpConversion.cpp

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
#include "flang/Optimizer/Transforms/CUFOpConversion.h"
1010
#include "flang/Optimizer/Builder/CUFCommon.h"
11+
#include "flang/Optimizer/Builder/Runtime/CUDA/Descriptor.h"
1112
#include "flang/Optimizer/Builder/Runtime/RTBuilder.h"
1213
#include "flang/Optimizer/CodeGen/TypeConverter.h"
1314
#include "flang/Optimizer/Dialect/CUF/CUFOps.h"
@@ -904,16 +905,7 @@ struct CUFSyncDescriptorOpConversion
904905

905906
auto hostAddr = builder.create<fir::AddrOfOp>(
906907
loc, fir::ReferenceType::get(globalOp.getType()), op.getGlobalName());
907-
mlir::func::FuncOp callee =
908-
fir::runtime::getRuntimeFunc<mkRTKey(CUFSyncGlobalDescriptor)>(loc,
909-
builder);
910-
auto fTy = callee.getFunctionType();
911-
mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
912-
mlir::Value sourceLine =
913-
fir::factory::locationToLineNo(builder, loc, fTy.getInput(2));
914-
llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
915-
builder, loc, fTy, hostAddr, sourceFile, sourceLine)};
916-
builder.create<fir::CallOp>(loc, callee, args);
908+
fir::runtime::cuda::genSyncGlobalDescriptor(builder, loc, hostAddr);
917909
op.erase();
918910
return mlir::success();
919911
}

flang/lib/Optimizer/Transforms/SimplifyIntrinsics.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1279,7 +1279,7 @@ void SimplifyIntrinsicsPass::runOnOperation() {
12791279
fir::KindMapping kindMap = fir::getKindMapping(module);
12801280
module.walk([&](mlir::Operation *op) {
12811281
if (auto call = mlir::dyn_cast<fir::CallOp>(op)) {
1282-
if (cuf::isInCUDADeviceContext(op))
1282+
if (cuf::isCUDADeviceContext(op))
12831283
return;
12841284
if (mlir::SymbolRefAttr callee = call.getCalleeAttr()) {
12851285
mlir::StringRef funcName = callee.getLeafReference().getValue();

flang/test/Lower/CUDA/cuda-pointer.cuf

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,31 @@
22

33
! Test lowering of CUDA pointers.
44

5+
module mod1
6+
7+
integer, device, pointer :: x(:)
8+
9+
contains
10+
511
subroutine allocate_pointer
612
real, device, pointer :: pr(:)
713
allocate(pr(10))
814
end
915

10-
! CHECK-LABEL: func.func @_QPallocate_pointer()
16+
! CHECK-LABEL: func.func @_QMmod1Pallocate_pointer()
1117
! CHECK-COUNT-2: fir.embox %{{.*}} {allocator_idx = 2 : i32} : (!fir.ptr<!fir.array<?xf32>>, !fir.shape<1>) -> !fir.box<!fir.ptr<!fir.array<?xf32>>>
18+
19+
subroutine c_f_pointer_sync
20+
use iso_c_binding
21+
use, intrinsic :: __fortran_builtins, only: c_devptr => __builtin_c_devptr
22+
type(c_devptr) :: cd1
23+
integer, parameter :: N = 2000
24+
call c_f_pointer(cd1, x, (/ 2000 /))
25+
end
26+
27+
! CHECK-LABEL: func.func @_QMmod1Pc_f_pointer_sync()
28+
! CHECK: %[[ADDR_X:.*]] = fir.address_of(@_QMmod1Ex) : !fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>
29+
! CHECK: %[[CONV:.*]] = fir.convert %[[ADDR_X]] : (!fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>) -> !fir.llvm_ptr<i8>
30+
! CHECK: fir.call @_FortranACUFSyncGlobalDescriptor(%[[CONV]], %{{.*}}, %{{.*}}) fastmath<contract> : (!fir.llvm_ptr<i8>, !fir.ref<i8>, i32) -> ()
31+
32+
end module

0 commit comments

Comments
 (0)