Skip to content

Commit 6e498bc

Browse files
authored
[flang][cuda] Handle simple device pointer allocation (#123996)
1 parent a939a9f commit 6e498bc

File tree

5 files changed

+102
-4
lines changed

5 files changed

+102
-4
lines changed
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
//===-- include/flang/Runtime/CUDA/pointer.h --------------------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef FORTRAN_RUNTIME_CUDA_POINTER_H_
10+
#define FORTRAN_RUNTIME_CUDA_POINTER_H_
11+
12+
#include "flang/Runtime/descriptor-consts.h"
13+
#include "flang/Runtime/entry-names.h"
14+
15+
namespace Fortran::runtime::cuda {
16+
17+
extern "C" {
18+
19+
/// Perform allocation of the descriptor.
20+
int RTDECL(CUFPointerAllocate)(Descriptor &, int64_t stream = -1,
21+
bool hasStat = false, const Descriptor *errMsg = nullptr,
22+
const char *sourceFile = nullptr, int sourceLine = 0);
23+
24+
} // extern "C"
25+
26+
} // namespace Fortran::runtime::cuda
27+
#endif // FORTRAN_RUNTIME_CUDA_POINTER_H_

flang/lib/Optimizer/Transforms/CUFOpConversion.cpp

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
#include "flang/Runtime/CUDA/common.h"
2121
#include "flang/Runtime/CUDA/descriptor.h"
2222
#include "flang/Runtime/CUDA/memory.h"
23+
#include "flang/Runtime/CUDA/pointer.h"
2324
#include "flang/Runtime/allocatable.h"
2425
#include "mlir/Conversion/LLVMCommon/Pattern.h"
2526
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
@@ -161,7 +162,18 @@ struct CUFAllocateOpConversion
161162
fir::FirOpBuilder builder(rewriter, mod);
162163
mlir::Location loc = op.getLoc();
163164

165+
bool isPointer = false;
166+
167+
if (auto declareOp =
168+
mlir::dyn_cast_or_null<fir::DeclareOp>(op.getBox().getDefiningOp()))
169+
if (declareOp.getFortranAttrs() &&
170+
bitEnumContainsAny(*declareOp.getFortranAttrs(),
171+
fir::FortranVariableFlagsEnum::pointer))
172+
isPointer = true;
173+
164174
if (hasDoubleDescriptors(op)) {
175+
if (isPointer)
176+
TODO(loc, "pointer allocation with double descriptors");
165177
// Allocation for module variable are done with custom runtime entry point
166178
// so the descriptors can be synchronized.
167179
mlir::func::FuncOp func;
@@ -176,13 +188,20 @@ struct CUFAllocateOpConversion
176188
}
177189

178190
mlir::func::FuncOp func;
179-
if (op.getSource())
191+
if (op.getSource()) {
192+
if (isPointer)
193+
TODO(loc, "pointer allocation with source");
180194
func =
181195
fir::runtime::getRuntimeFunc<mkRTKey(CUFAllocatableAllocateSource)>(
182196
loc, builder);
183-
else
184-
func = fir::runtime::getRuntimeFunc<mkRTKey(CUFAllocatableAllocate)>(
185-
loc, builder);
197+
} else {
198+
func =
199+
isPointer
200+
? fir::runtime::getRuntimeFunc<mkRTKey(CUFPointerAllocate)>(
201+
loc, builder)
202+
: fir::runtime::getRuntimeFunc<mkRTKey(CUFAllocatableAllocate)>(
203+
loc, builder);
204+
}
186205

187206
return convertOpToCall<cuf::AllocateOp>(op, rewriter, func);
188207
}

flang/runtime/CUDA/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ add_flang_library(${CUFRT_LIBNAME}
2020
kernel.cpp
2121
memmove-function.cpp
2222
memory.cpp
23+
pointer.cpp
2324
registration.cpp
2425
)
2526

flang/runtime/CUDA/pointer.cpp

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
//===-- runtime/CUDA/pointer.cpp ------------------------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "flang/Runtime/CUDA/pointer.h"
10+
#include "../stat.h"
11+
#include "../terminator.h"
12+
#include "flang/Runtime/pointer.h"
13+
14+
#include "cuda_runtime.h"
15+
16+
namespace Fortran::runtime::cuda {
17+
18+
extern "C" {
19+
RT_EXT_API_GROUP_BEGIN
20+
21+
int RTDEF(CUFPointerAllocate)(Descriptor &desc, int64_t stream, bool hasStat,
22+
const Descriptor *errMsg, const char *sourceFile, int sourceLine) {
23+
if (desc.HasAddendum()) {
24+
Terminator terminator{sourceFile, sourceLine};
25+
// TODO: This require a bit more work to set the correct type descriptor
26+
// address
27+
terminator.Crash(
28+
"not yet implemented: CUDA descriptor allocation with addendum");
29+
}
30+
// Perform the standard allocation.
31+
int stat{
32+
RTNAME(PointerAllocate)(desc, hasStat, errMsg, sourceFile, sourceLine)};
33+
return stat;
34+
}
35+
36+
RT_EXT_API_GROUP_END
37+
38+
} // extern "C"
39+
40+
} // namespace Fortran::runtime::cuda

flang/test/Fir/CUDA/cuda-allocate.fir

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,4 +181,15 @@ func.func @_QQallocate_stream() {
181181
// CHECK: %[[STREAM_LOAD:.*]] = fir.load %[[STREAM]] : !fir.ref<i64>
182182
// CHECK: fir.call @_FortranACUFAllocatableAllocate(%{{.*}}, %[[STREAM_LOAD]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, i64, i1, !fir.box<none>, !fir.ref<i8>, i32) -> i32
183183

184+
185+
func.func @_QPp_alloc() {
186+
%0 = cuf.alloc !fir.box<!fir.ptr<!fir.array<?xcomplex<f32>>>> {bindc_name = "complex_array", data_attr = #cuf.cuda<device>, uniq_name = "_QFp_allocEcomplex_array"} -> !fir.ref<!fir.box<!fir.ptr<!fir.array<?xcomplex<f32>>>>>
187+
%4 = fir.declare %0 {data_attr = #cuf.cuda<device>, fortran_attrs = #fir.var_attrs<pointer>, uniq_name = "_QFp_allocEcomplex_array"} : (!fir.ref<!fir.box<!fir.ptr<!fir.array<?xcomplex<f32>>>>>) -> !fir.ref<!fir.box<!fir.ptr<!fir.array<?xcomplex<f32>>>>>
188+
%9 = cuf.allocate %4 : !fir.ref<!fir.box<!fir.ptr<!fir.array<?xcomplex<f32>>>>> {data_attr = #cuf.cuda<device>} -> i32
189+
return
190+
}
191+
192+
// CHECK-LABEL: func.func @_QPp_alloc()
193+
// CHECK: fir.call @_FortranACUFPointerAllocate
194+
184195
} // end of module

0 commit comments

Comments
 (0)