Skip to content

Commit 2130285

Browse files
authored
[flang][cuda] Make sure allocator id is set for pointer allocate (#129950)
1 parent 45ca613 commit 2130285

File tree

4 files changed: 40 additions and 23 deletions.

flang/include/flang/Lower/Cuda.h

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ namespace Fortran::lower {
2525
// for it.
2626
// If the insertion point is inside an OpenACC region op, it is considered
2727
// device context.
28-
static bool isCudaDeviceContext(fir::FirOpBuilder &builder) {
28+
static bool inline isCudaDeviceContext(fir::FirOpBuilder &builder) {
2929
if (builder.getRegion().getParentOfType<cuf::KernelOp>())
3030
return true;
3131
if (builder.getRegion()
@@ -41,6 +41,23 @@ static bool isCudaDeviceContext(fir::FirOpBuilder &builder) {
4141
}
4242
return false;
4343
}
44+
45+
// Select the allocator index for \p sym from its CUDA data attribute.
// The attribute is looked up on the ultimate symbol (sym.GetUltimate()).
// Pinned, Device, Managed and Unified each map to their own allocator
// position constant; every other case yields kDefaultAllocator.
static inline unsigned getAllocatorIdx(const Fortran::semantics::Symbol &sym) {
46+
std::optional<Fortran::common::CUDADataAttr> cudaAttr =
47+
Fortran::semantics::GetCUDADataAttr(&sym.GetUltimate());
48+
if (cudaAttr) {
49+
if (*cudaAttr == Fortran::common::CUDADataAttr::Pinned)
50+
return kPinnedAllocatorPos;
51+
if (*cudaAttr == Fortran::common::CUDADataAttr::Device)
52+
return kDeviceAllocatorPos;
53+
if (*cudaAttr == Fortran::common::CUDADataAttr::Managed)
54+
return kManagedAllocatorPos;
55+
if (*cudaAttr == Fortran::common::CUDADataAttr::Unified)
56+
return kUnifiedAllocatorPos;
57+
}
58+
// Reached when the symbol has no CUDA data attribute, or has one that is
// not among the four handled above.
return kDefaultAllocator;
59+
}
60+
4461
} // end namespace Fortran::lower
4562

4663
#endif // FORTRAN_LOWER_CUDA_H

flang/lib/Lower/Allocatable.cpp

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -475,6 +475,7 @@ class AllocateStmtHelper {
475475
!alloc.type.IsPolymorphic() &&
476476
!alloc.hasCoarraySpec() && !useAllocateRuntime &&
477477
!box.isPointer();
478+
unsigned allocatorIdx = Fortran::lower::getAllocatorIdx(alloc.getSymbol());
478479

479480
if (inlineAllocation &&
480481
((isCudaSymbol && isCudaDeviceContext) || !isCudaSymbol)) {
@@ -488,7 +489,7 @@ class AllocateStmtHelper {
488489

489490
// Generate a sequence of runtime calls.
490491
errorManager.genStatCheck(builder, loc);
491-
genAllocateObjectInit(box);
492+
genAllocateObjectInit(box, allocatorIdx);
492493
if (alloc.hasCoarraySpec())
493494
TODO(loc, "coarray: allocation of a coarray object");
494495
if (alloc.type.IsPolymorphic())
@@ -549,14 +550,16 @@ class AllocateStmtHelper {
549550
TODO(loc, "derived type length parameters in allocate");
550551
}
551552

552-
void genAllocateObjectInit(const fir::MutableBoxValue &box) {
553+
void genAllocateObjectInit(const fir::MutableBoxValue &box,
554+
unsigned allocatorIdx) {
553555
if (box.isPointer()) {
554556
// For pointers, the descriptor may still be uninitialized (see Fortran
555557
// 2018 19.5.2.2). The allocation runtime needs to be given a descriptor
556558
// with initialized rank, types and attributes. Initialize the descriptor
557559
// here to ensure these constraints are fulfilled.
558560
mlir::Value nullPointer = fir::factory::createUnallocatedBox(
559-
builder, loc, box.getBoxTy(), box.nonDeferredLenParams());
561+
builder, loc, box.getBoxTy(), box.nonDeferredLenParams(),
562+
/*typeSourceBox=*/{}, allocatorIdx);
560563
builder.create<fir::StoreOp>(loc, nullPointer, box.getAddr());
561564
} else {
562565
assert(box.isAllocatable() && "must be an allocatable");
@@ -612,11 +615,12 @@ class AllocateStmtHelper {
612615

613616
void genSourceMoldAllocation(const Allocation &alloc,
614617
const fir::MutableBoxValue &box, bool isSource) {
618+
unsigned allocatorIdx = Fortran::lower::getAllocatorIdx(alloc.getSymbol());
615619
fir::ExtendedValue exv = isSource ? sourceExv : moldExv;
616-
;
620+
617621
// Generate a sequence of runtime calls.
618622
errorManager.genStatCheck(builder, loc);
619-
genAllocateObjectInit(box);
623+
genAllocateObjectInit(box, allocatorIdx);
620624
if (alloc.hasCoarraySpec())
621625
TODO(loc, "coarray: allocation of a coarray object");
622626
// Set length of the allocate object if it has. Otherwise, get the length

flang/lib/Lower/ConvertVariable.cpp

Lines changed: 2 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
#include "flang/Lower/ConvertExpr.h"
2020
#include "flang/Lower/ConvertExprToHLFIR.h"
2121
#include "flang/Lower/ConvertProcedureDesignator.h"
22+
#include "flang/Lower/Cuda.h"
2223
#include "flang/Lower/Mangler.h"
2324
#include "flang/Lower/PFTBuilder.h"
2425
#include "flang/Lower/StatementContext.h"
@@ -1985,22 +1986,6 @@ static void genBoxDeclare(Fortran::lower::AbstractConverter &converter,
19851986
replace);
19861987
}
19871988

1988-
/// Select the allocator index for \p sym from the CUDA data attribute of its
/// ultimate symbol: Pinned, Device, Managed and Unified each map to their own
/// allocator position constant; anything else yields kDefaultAllocator.
static unsigned getAllocatorIdx(const Fortran::semantics::Symbol &sym) {
1989-
std::optional<Fortran::common::CUDADataAttr> cudaAttr =
1990-
Fortran::semantics::GetCUDADataAttr(&sym.GetUltimate());
1991-
if (cudaAttr) {
1992-
if (*cudaAttr == Fortran::common::CUDADataAttr::Pinned)
1993-
return kPinnedAllocatorPos;
1994-
if (*cudaAttr == Fortran::common::CUDADataAttr::Device)
1995-
return kDeviceAllocatorPos;
1996-
if (*cudaAttr == Fortran::common::CUDADataAttr::Managed)
1997-
return kManagedAllocatorPos;
1998-
if (*cudaAttr == Fortran::common::CUDADataAttr::Unified)
1999-
return kUnifiedAllocatorPos;
2000-
}
2001-
/// Reached when there is no CUDA data attribute, or one that is not among
/// the four handled above.
return kDefaultAllocator;
2002-
}
2003-
20041989
/// Lower specification expressions and attributes of variable \p var and
20051990
/// add it to the symbol map. For a global or an alias, the address must be
20061991
/// pre-computed and provided in \p preAlloc. A dummy argument for the current
@@ -2091,7 +2076,7 @@ void Fortran::lower::mapSymbolAttributes(
20912076
converter, loc, var, boxAlloc, nonDeferredLenParams,
20922077
/*alwaysUseBox=*/
20932078
converter.getLoweringOptions().getLowerToHighLevelFIR(),
2094-
getAllocatorIdx(var.getSymbol()));
2079+
Fortran::lower::getAllocatorIdx(var.getSymbol()));
20952080
genAllocatableOrPointerDeclare(converter, symMap, var.getSymbol(), box,
20962081
replace);
20972082
return;
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
! RUN: bbc -emit-hlfir -fcuda %s -o - | FileCheck %s
2+
3+
! Test lowering of CUDA pointers.
4+
5+
subroutine allocate_pointer
6+
real, device, pointer :: pr(:)
7+
allocate(pr(10))
8+
end
9+
10+
! CHECK-LABEL: func.func @_QPallocate_pointer()
11+
! CHECK-COUNT-2: fir.embox %{{.*}} {allocator_idx = 2 : i32} : (!fir.ptr<!fir.array<?xf32>>, !fir.shape<1>) -> !fir.box<!fir.ptr<!fir.array<?xf32>>>

0 commit comments

Comments
 (0)