Skip to content

Commit 1fcb6a9

Browse files
authored
[flang][OpenMP] Initialize allocatable members of derived types (#120295)
Allocatable members of privatized derived types must be allocated, with the same bounds as in the original object, whenever that member is also allocated in the original object, but Flang was not performing such initialization. The `Initialize` runtime function can't perform this task unless its signature is changed to receive an additional parameter, the original object, which is needed to find out which allocatable members must also be allocated in the clone, and with which bounds. As `Initialize` is used not only for privatization, sometimes this other object won't even exist, so this new parameter would need to be optional. Because of this, it seemed better to add a new runtime function: `InitializeClone`. To avoid unnecessary calls, lowering inserts a call to it only for privatized items that are derived types with allocatable members. Fixes #114888. Fixes #114889.
1 parent bdf2555 commit 1fcb6a9

File tree

13 files changed

+262
-4
lines changed

13 files changed

+262
-4
lines changed

flang/include/flang/Lower/AbstractConverter.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,9 @@ class AbstractConverter {
8888
/// Get the mlir instance of a symbol.
8989
virtual mlir::Value getSymbolAddress(SymbolRef sym) = 0;
9090

91+
virtual fir::ExtendedValue
92+
symBoxToExtendedValue(const Fortran::lower::SymbolBox &symBox) = 0;
93+
9194
virtual fir::ExtendedValue
9295
getSymbolExtendedValue(const Fortran::semantics::Symbol &sym,
9396
Fortran::lower::SymMap *symMap = nullptr) = 0;

flang/include/flang/Lower/ConvertVariable.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,11 @@ void defaultInitializeAtRuntime(Fortran::lower::AbstractConverter &converter,
7070
const Fortran::semantics::Symbol &sym,
7171
Fortran::lower::SymMap &symMap);
7272

73+
/// Call clone initialization runtime routine to initialize \p sym's value.
74+
void initializeCloneAtRuntime(Fortran::lower::AbstractConverter &converter,
75+
const Fortran::semantics::Symbol &sym,
76+
Fortran::lower::SymMap &symMap);
77+
7378
/// Create a fir::GlobalOp given a module variable definition. This is intended
7479
/// to be used when lowering a module definition, not when lowering variables
7580
/// used from a module. For used variables instantiateVariable must directly be

flang/include/flang/Optimizer/Builder/Runtime/Derived.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,12 @@ namespace fir::runtime {
2626
void genDerivedTypeInitialize(fir::FirOpBuilder &builder, mlir::Location loc,
2727
mlir::Value box);
2828

29+
/// Generate call to derived type clone initialization runtime routine to
30+
/// initialize \p newBox from \p box.
31+
void genDerivedTypeInitializeClone(fir::FirOpBuilder &builder,
32+
mlir::Location loc, mlir::Value newBox,
33+
mlir::Value box);
34+
2935
/// Generate call to derived type destruction runtime routine to
3036
/// destroy \p box.
3137
void genDerivedTypeDestroy(fir::FirOpBuilder &builder, mlir::Location loc,

flang/include/flang/Runtime/derived-api.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,13 @@ extern "C" {
3232
void RTDECL(Initialize)(
3333
const Descriptor &, const char *sourceFile = nullptr, int sourceLine = 0);
3434

35+
// Initializes an object clone from the original object.
36+
// Each allocatable member of the clone is allocated with the same bounds as
37+
// in the original object, if it is also allocated in it.
38+
// Both descriptors must be initialized and non-null.
39+
void RTDECL(InitializeClone)(const Descriptor &, const Descriptor &,
40+
const char *sourceFile = nullptr, int sourceLine = 0);
41+
3542
// Finalizes an object and its components. Deallocates any
3643
// allocatable/automatic components. Does not deallocate the descriptor's
3744
// storage.

flang/lib/Lower/Bridge.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -557,8 +557,8 @@ class FirConverter : public Fortran::lower::AbstractConverter {
557557
return lookupSymbol(sym).getAddr();
558558
}
559559

560-
fir::ExtendedValue
561-
symBoxToExtendedValue(const Fortran::lower::SymbolBox &symBox) {
560+
fir::ExtendedValue symBoxToExtendedValue(
561+
const Fortran::lower::SymbolBox &symBox) override final {
562562
return symBox.match(
563563
[](const Fortran::lower::SymbolBox::Intrinsic &box)
564564
-> fir::ExtendedValue { return box.getAddr(); },

flang/lib/Lower/ConvertVariable.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -798,6 +798,20 @@ void Fortran::lower::defaultInitializeAtRuntime(
798798
}
799799
}
800800

801+
/// Call clone initialization runtime routine to initialize \p sym's value.
802+
void Fortran::lower::initializeCloneAtRuntime(
803+
Fortran::lower::AbstractConverter &converter,
804+
const Fortran::semantics::Symbol &sym, Fortran::lower::SymMap &symMap) {
805+
fir::FirOpBuilder &builder = converter.getFirOpBuilder();
806+
mlir::Location loc = converter.getCurrentLocation();
807+
fir::ExtendedValue exv = converter.getSymbolExtendedValue(sym, &symMap);
808+
mlir::Value newBox = builder.createBox(loc, exv);
809+
lower::SymbolBox hsb = converter.lookupOneLevelUpSymbol(sym);
810+
fir::ExtendedValue hexv = converter.symBoxToExtendedValue(hsb);
811+
mlir::Value box = builder.createBox(loc, hexv);
812+
fir::runtime::genDerivedTypeInitializeClone(builder, loc, newBox, box);
813+
}
814+
801815
enum class VariableCleanUp { Finalize, Deallocate };
802816
/// Check whether a local variable needs to be finalized according to clause
803817
/// 7.5.6.3 point 3 or if it is an allocatable that must be deallocated. Note

flang/lib/Lower/OpenMP/DataSharingProcessor.cpp

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,23 @@ void DataSharingProcessor::cloneSymbol(const semantics::Symbol *sym) {
116116
*sym, /*skipDefaultInit=*/isFirstPrivate);
117117
(void)success;
118118
assert(success && "Privatization failed due to existing binding");
119+
120+
// Initialize clone from original object if it has any allocatable member.
121+
auto needInitClone = [&] {
122+
if (isFirstPrivate)
123+
return false;
124+
125+
SymbolBox sb = symTable.lookupSymbol(sym);
126+
assert(sb);
127+
mlir::Value addr = sb.getAddr();
128+
assert(addr);
129+
return hlfir::mayHaveAllocatableComponent(addr.getType());
130+
};
131+
132+
if (needInitClone()) {
133+
Fortran::lower::initializeCloneAtRuntime(converter, *sym, symTable);
134+
callsInitClone = true;
135+
}
119136
}
120137

121138
void DataSharingProcessor::copyFirstPrivateSymbol(
@@ -165,8 +182,8 @@ bool DataSharingProcessor::needBarrier() {
165182
// variables.
166183
// Emit implicit barrier for linear clause. Maybe on somewhere else.
167184
for (const semantics::Symbol *sym : allPrivatizedSymbols) {
168-
if (sym->test(semantics::Symbol::Flag::OmpFirstPrivate) &&
169-
sym->test(semantics::Symbol::Flag::OmpLastPrivate))
185+
if (sym->test(semantics::Symbol::Flag::OmpLastPrivate) &&
186+
(sym->test(semantics::Symbol::Flag::OmpFirstPrivate) || callsInitClone))
170187
return true;
171188
}
172189
return false;

flang/lib/Lower/OpenMP/DataSharingProcessor.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@ class DataSharingProcessor {
8686
lower::pft::Evaluation &eval;
8787
bool shouldCollectPreDeterminedSymbols;
8888
bool useDelayedPrivatization;
89+
bool callsInitClone = false;
8990
lower::SymMap &symTable;
9091
OMPConstructSymbolVisitor visitor;
9192

flang/lib/Optimizer/Builder/Runtime/Derived.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,21 @@ void fir::runtime::genDerivedTypeInitialize(fir::FirOpBuilder &builder,
2929
builder.create<fir::CallOp>(loc, func, args);
3030
}
3131

32+
/// Generate a call to the `InitializeClone` runtime entry point.
///
/// \p newBox is passed as the first runtime argument and \p box as the
/// second; the source file name and line number derived from \p loc are
/// appended as the third and fourth arguments.
void fir::runtime::genDerivedTypeInitializeClone(fir::FirOpBuilder &builder,
                                                 mlir::Location loc,
                                                 mlir::Value newBox,
                                                 mlir::Value box) {
  // Position of the source-line operand in the runtime function signature;
  // its declared type is what the line number value is created with.
  constexpr unsigned sourceLineArgIdx = 3;

  mlir::func::FuncOp callee =
      fir::runtime::getRuntimeFunc<mkRTKey(InitializeClone)>(loc, builder);
  mlir::FunctionType calleeTy = callee.getFunctionType();

  mlir::Value fileName = fir::factory::locationToFilename(builder, loc);
  mlir::Value lineNo = fir::factory::locationToLineNo(
      builder, loc, calleeTy.getInput(sourceLineArgIdx));

  llvm::SmallVector<mlir::Value> operands = fir::runtime::createArguments(
      builder, loc, calleeTy, newBox, box, fileName, lineNo);
  builder.create<fir::CallOp>(loc, callee, operands);
}
46+
3247
void fir::runtime::genDerivedTypeDestroy(fir::FirOpBuilder &builder,
3348
mlir::Location loc, mlir::Value box) {
3449
auto func = fir::runtime::getRuntimeFunc<mkRTKey(Destroy)>(loc, builder);

flang/runtime/derived-api.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,16 @@ void RTDEF(Initialize)(
3131
}
3232
}
3333

34+
// API entry point for clone initialization.
// Does nothing unless `clone` carries an addendum that names a derived type;
// otherwise forwards to the internal InitializeClone with a Terminator built
// from the given source position. The STAT code returned by the internal
// routine is not used here.
void RTDEF(InitializeClone)(const Descriptor &clone, const Descriptor &orig,
    const char *sourceFile, int sourceLine) {
  const DescriptorAddendum *addendum{clone.Addendum()};
  if (!addendum) {
    return;
  }
  const auto *cloneType{addendum->derivedType()};
  if (!cloneType) {
    return;
  }
  Terminator terminator{sourceFile, sourceLine};
  InitializeClone(clone, orig, *cloneType, terminator);
}
43+
3444
void RTDEF(Destroy)(const Descriptor &descriptor) {
3545
if (const DescriptorAddendum * addendum{descriptor.Addendum()}) {
3646
if (const auto *derived{addendum->derivedType()}) {

flang/runtime/derived.cpp

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,84 @@ RT_API_ATTRS int Initialize(const Descriptor &instance,
122122
return stat;
123123
}
124124

125+
// Initializes `clone` from `orig`, both instances of `derived`.
//
// For every element of `orig` and every component of `derived`:
//  - an allocatable component that is allocated in `orig` is given the shape
//    of the original (ApplyMold) and allocated in `clone`; if the allocated
//    object is itself of derived type it is default-initialized when its type
//    requires it, and then processed recursively;
//  - a non-allocatable data component of derived type is processed
//    recursively in place.
// Other components are left untouched.
//
// `hasStat` and `errMsg` are forwarded to ReturnError/Initialize.
// Returns a STAT= code (0 when all's well). Processing stops at the first
// failure so that the reported code is the one of the failing operation.
//
// NOTE(review): the subscripts used to address elements of `clone` are
// computed from `orig`'s bounds, so both descriptors are assumed to have the
// same lower bounds and extents -- confirm at the call sites.
RT_API_ATTRS int InitializeClone(const Descriptor &clone,
    const Descriptor &orig, const typeInfo::DerivedType &derived,
    Terminator &terminator, bool hasStat, const Descriptor *errMsg) {
  const Descriptor &componentDesc{derived.component()};
  std::size_t elements{orig.Elements()};
  int stat{StatOk};

  // Visit each data component. The loop also stops as soon as an error has
  // been recorded: the `break`s below only leave the per-element loops, and
  // without this check a later component could overwrite `stat` with StatOk.
  std::size_t components{componentDesc.Elements()};
  for (std::size_t i{0}; i < components && stat == StatOk; ++i) {
    const typeInfo::Component &comp{
        *componentDesc.ZeroBasedIndexedElement<typeInfo::Component>(i)};
    // Restart the element walk from the first element for each component.
    SubscriptValue at[maxRank];
    orig.GetLowerBounds(at);
    if (comp.genre() == typeInfo::Component::Genre::Allocatable) {
      // Allocate allocatable components that are also allocated in the
      // original object, element by element.
      for (std::size_t j{0}; j < elements; ++j, orig.IncrementSubscripts(at)) {
        Descriptor &origDesc{
            *orig.ElementComponent<Descriptor>(at, comp.offset())};
        Descriptor &cloneDesc{
            *clone.ElementComponent<Descriptor>(at, comp.offset())};
        if (origDesc.IsAllocated()) {
          // Take the shape from the original before allocating.
          cloneDesc.ApplyMold(origDesc, origDesc.rank());
          stat = ReturnError(terminator, cloneDesc.Allocate(), errMsg, hasStat);
          if (stat == StatOk) {
            if (const DescriptorAddendum * addendum{cloneDesc.Addendum()}) {
              // Named differently from the `derived` parameter to avoid
              // shadowing it.
              if (const typeInfo::DerivedType *
                      allocType{addendum->derivedType()}) {
                if (!allocType->noInitializationNeeded()) {
                  // Perform default initialization for the allocated element.
                  stat = Initialize(
                      cloneDesc, *allocType, terminator, hasStat, errMsg);
                }
                // Initialize the allocated object's own allocatables.
                if (stat == StatOk) {
                  stat = InitializeClone(cloneDesc, origDesc, *allocType,
                      terminator, hasStat, errMsg);
                }
              }
            }
          }
        }
        if (stat != StatOk) {
          break;
        }
      }
    } else if (comp.genre() == typeInfo::Component::Genre::Data &&
        comp.derivedType()) {
      // Handle nested derived types.
      const typeInfo::DerivedType &compType{*comp.derivedType()};
      SubscriptValue extents[maxRank];
      GetComponentExtents(extents, comp, orig);
      // Data components have no descriptor of their own: establish temporary
      // stack descriptors that address the component storage in place.
      StaticDescriptor<maxRank, true, 0> origStaticDesc;
      StaticDescriptor<maxRank, true, 0> cloneStaticDesc;
      Descriptor &origDesc{origStaticDesc.descriptor()};
      Descriptor &cloneDesc{cloneStaticDesc.descriptor()};
      // Initialize each element.
      for (std::size_t j{0}; j < elements; ++j, orig.IncrementSubscripts(at)) {
        origDesc.Establish(compType,
            orig.ElementComponent<char>(at, comp.offset()), comp.rank(),
            extents);
        cloneDesc.Establish(compType,
            clone.ElementComponent<char>(at, comp.offset()), comp.rank(),
            extents);
        stat = InitializeClone(
            cloneDesc, origDesc, compType, terminator, hasStat, errMsg);
        if (stat != StatOk) {
          break;
        }
      }
    }
  }
  return stat;
}
202+
125203
static RT_API_ATTRS const typeInfo::SpecialBinding *FindFinal(
126204
const typeInfo::DerivedType &derived, int rank) {
127205
if (const auto *ranked{derived.FindSpecialBinding(

flang/runtime/derived.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,14 @@ class Terminator;
2626
RT_API_ATTRS int Initialize(const Descriptor &, const typeInfo::DerivedType &,
2727
Terminator &, bool hasStat = false, const Descriptor *errMsg = nullptr);
2828

29+
// Initializes an object clone from the original object.
30+
// Each allocatable member of the clone is allocated with the same bounds as
31+
// in the original object, if it is also allocated in it.
32+
// Returns a STAT= code (0 when all's well).
33+
RT_API_ATTRS int InitializeClone(const Descriptor &, const Descriptor &,
34+
const typeInfo::DerivedType &, Terminator &, bool hasStat = false,
35+
const Descriptor *errMsg = nullptr);
36+
2937
// Call FINAL subroutines, if any
3038
RT_API_ATTRS void Finalize(
3139
const Descriptor &, const typeInfo::DerivedType &derived, Terminator *);
Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
! Test that derived type allocatable members of private copies are properly
2+
! initialized.
3+
!RUN: %flang_fc1 -emit-hlfir -fopenmp %s -o - | FileCheck %s
4+
5+
module m1
6+
type x
7+
integer, allocatable :: x1(:)
8+
end type
9+
10+
type y
11+
integer :: y1(10)
12+
end type
13+
14+
contains
15+
16+
!CHECK-LABEL: omp.private {type = private} @_QMm1Ftest_nested
17+
!CHECK: fir.call @_FortranAInitializeClone
18+
!CHECK-NEXT: omp.yield
19+
20+
!CHECK-LABEL: omp.private {type = private} @_QMm1Ftest_array_of_allocs
21+
!CHECK: fir.call @_FortranAInitializeClone
22+
!CHECK-NEXT: omp.yield
23+
24+
!CHECK-LABEL: omp.private {type = firstprivate} @_QMm1Ftest_array
25+
!CHECK-NOT: fir.call @_FortranAInitializeClone
26+
!CHECK: omp.yield
27+
28+
!CHECK-LABEL: omp.private {type = private} @_QMm1Ftest_array
29+
!CHECK: fir.call @_FortranAInitializeClone
30+
!CHECK-NEXT: omp.yield
31+
32+
!CHECK-LABEL: omp.private {type = private} @_QMm1Ftest_scalar
33+
!CHECK: fir.call @_FortranAInitializeClone
34+
!CHECK-NEXT: omp.yield
35+
36+
subroutine test_scalar()
37+
type(x) :: v
38+
allocate(v%x1(5))
39+
40+
!$omp parallel private(v)
41+
!$omp end parallel
42+
end subroutine
43+
44+
! Test omp sections lastprivate(v, v2)
45+
! - InitializeClone must not be called for v2, that doesn't have an
46+
! allocatable member.
47+
! - InitializeClone must be called for v, that has an allocatable member.
48+
! - To avoid race conditions between InitializeClone and lastprivate, a
49+
! barrier must be present after the initializations.
50+
!CHECK-LABEL: func @_QMm1Ptest_array
51+
!CHECK: fir.call @_FortranAInitializeClone
52+
!CHECK-NEXT: omp.barrier
53+
subroutine test_array()
54+
type(x) :: v(10)
55+
type(y) :: v2(10)
56+
allocate(v(1)%x1(5))
57+
58+
!$omp parallel private(v)
59+
!$omp end parallel
60+
61+
!$omp parallel
62+
!$omp sections lastprivate(v2, v)
63+
!$omp end sections
64+
!$omp end parallel
65+
66+
!$omp parallel firstprivate(v)
67+
!$omp end parallel
68+
end subroutine
69+
70+
subroutine test_array_of_allocs()
71+
type(x), allocatable :: v(:)
72+
allocate(v(10))
73+
allocate(v(1)%x1(5))
74+
75+
!$omp parallel private(v)
76+
!$omp end parallel
77+
end subroutine
78+
79+
subroutine test_nested()
80+
type dt1
81+
integer, allocatable :: a(:)
82+
end type
83+
84+
type dt2
85+
type(dt1) :: d1
86+
end type
87+
88+
type(dt2) :: d2
89+
allocate(d2%d1%a(10))
90+
91+
!$omp parallel private(d2)
92+
!$omp end parallel
93+
end subroutine
94+
end module

0 commit comments

Comments
 (0)