Skip to content

Commit 0446bfc

Browse files
committed
[flang][hlfir] Codegen of hlfir.region_assign where LHS conflicts
When the analysis of hlfir.region_assign determined that the LHS region evaluation may be impacted by the assignment effects, all LHS must be fully evaluated and saved before any assignment is done. This patch adds TemporaryStorage variants to save address, including vector subscripted entities addresses whose shape must be saved. It uses the DescriptorStack runtime to deal with complex cases inside forall. For the sake of simplicity, this is also used for vector subscripted LHS outside of foralls (each element address is saved as a descriptor on this stack. This is a bit suboptimal, but it is a safe start that will work with all kinds of type (polymorphic, PDTs...) without further work). Another approach would be to saved only the values that are conflicting in the LHS computation, but this would require a much more complex analysis of the LHS region DAG. Differential Revision: https://reviews.llvm.org/D154057
1 parent 1233e2e commit 0446bfc

File tree

6 files changed

+618
-19
lines changed

6 files changed

+618
-19
lines changed

flang/include/flang/Optimizer/Builder/Runtime/TemporaryStack.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,5 +30,16 @@ void genValueAt(mlir::Location loc, fir::FirOpBuilder &builder,
3030
void genDestroyValueStack(mlir::Location loc, fir::FirOpBuilder &builder,
3131
mlir::Value opaquePtr);
3232

33+
mlir::Value genCreateDescriptorStack(mlir::Location loc,
34+
fir::FirOpBuilder &builder);
35+
36+
void genPushDescriptor(mlir::Location loc, fir::FirOpBuilder &builder,
37+
mlir::Value opaquePtr, mlir::Value boxValue);
38+
void genDescriptorAt(mlir::Location loc, fir::FirOpBuilder &builder,
39+
mlir::Value opaquePtr, mlir::Value i,
40+
mlir::Value retValueBox);
41+
42+
void genDestroyDescriptorStack(mlir::Location loc, fir::FirOpBuilder &builder,
43+
mlir::Value opaquePtr);
3344
} // namespace fir::runtime
3445
#endif // FORTRAN_OPTIMIZER_BUILDER_RUNTIME_TEMPORARYSTACK_H

flang/include/flang/Optimizer/Builder/TemporaryStorage.h

Lines changed: 87 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,30 @@ class SimpleCopy {
120120
hlfir::AssociateOp copy;
121121
};
122122

123+
/// Structure to keep track of a simple mlir::Value. This is useful
124+
/// when a value does not need an in memory copy because it is
125+
/// already saved in an SSA value that will be accessible at the fetching
126+
/// point.
127+
class SSARegister {
128+
public:
129+
SSARegister(){};
130+
131+
void pushValue(mlir::Location loc, fir::FirOpBuilder &builder,
132+
mlir::Value value) {
133+
ssaRegister = value;
134+
}
135+
void resetFetchPosition(mlir::Location loc, fir::FirOpBuilder &builder){};
136+
mlir::Value fetch(mlir::Location loc, fir::FirOpBuilder &builder) {
137+
return ssaRegister;
138+
}
139+
void destroy(mlir::Location loc, fir::FirOpBuilder &builder) {}
140+
bool canBeFetchedAfterPush() const { return true; }
141+
142+
public:
143+
/// Temporary storage for the copy.
144+
mlir::Value ssaRegister;
145+
};
146+
123147
/// Data structure to stack any kind of values with the same static type and
124148
/// rank. Each value may have different type parameters, bounds, and dynamic
125149
/// type. Fetching value N will return a value with the same dynamic type,
@@ -150,6 +174,61 @@ class AnyValueStack {
150174
mlir::Value retValueBox;
151175
};
152176

177+
/// Data structure to stack any kind of variables with the same static type and
178+
/// rank. Each variable may have different type parameters, bounds, and dynamic
179+
/// type. Fetching variable N will return a variable with the same address,
180+
/// dynamic type, bounds, and type parameters as the Nth variable that was
181+
/// pushed. It is implemented using runtime.
182+
class AnyVariableStack {
183+
public:
184+
AnyVariableStack(mlir::Location loc, fir::FirOpBuilder &builder,
185+
mlir::Type valueStaticType);
186+
187+
void pushValue(mlir::Location loc, fir::FirOpBuilder &builder,
188+
mlir::Value value);
189+
void resetFetchPosition(mlir::Location loc, fir::FirOpBuilder &builder);
190+
mlir::Value fetch(mlir::Location loc, fir::FirOpBuilder &builder);
191+
void destroy(mlir::Location loc, fir::FirOpBuilder &builder);
192+
bool canBeFetchedAfterPush() const { return true; }
193+
194+
private:
195+
/// Keep the original variable type.
196+
mlir::Type variableStaticType;
197+
/// Runtime cookie created by the runtime. It is a pointer to an opaque
198+
/// runtime data structure that manages the stack.
199+
mlir::Value opaquePtr;
200+
/// Counter to keep track of the fetching position.
201+
Counter counter;
202+
/// Pointer box passed to the runtime when fetching the values.
203+
mlir::Value retValueBox;
204+
};
205+
206+
class TemporaryStorage;
207+
208+
/// Data structure to stack vector subscripted entity shape and
209+
/// element addresses. AnyVariableStack allows saving vector subscripted
210+
/// entities element addresses, but when saving several vector subscripted
211+
/// entities on a stack, and if the context does not allow retrieving the
212+
/// vector subscript entities shapes, these shapes must be saved too.
213+
class AnyVectorSubscriptStack : public AnyVariableStack {
214+
public:
215+
AnyVectorSubscriptStack(mlir::Location loc, fir::FirOpBuilder &builder,
216+
mlir::Type valueStaticType,
217+
bool shapeCanBeSavedAsRegister, int rank);
218+
void pushShape(mlir::Location loc, fir::FirOpBuilder &builder,
219+
mlir::Value shape);
220+
void resetFetchPosition(mlir::Location loc, fir::FirOpBuilder &builder);
221+
mlir::Value fetchShape(mlir::Location loc, fir::FirOpBuilder &builder);
222+
void destroy(mlir::Location loc, fir::FirOpBuilder &builder);
223+
bool canBeFetchedAfterPush() const { return true; }
224+
225+
private:
226+
std::unique_ptr<TemporaryStorage> shapeTemp;
227+
// If the shape is saved inside a descriptor (as extents),
228+
// keep track of the descriptor type.
229+
std::optional<mlir::Type> boxType;
230+
};
231+
153232
/// Generic wrapper over the different sorts of temporary storages.
154233
class TemporaryStorage {
155234
public:
@@ -178,8 +257,15 @@ class TemporaryStorage {
178257
impl);
179258
}
180259

260+
template <typename T>
261+
T &cast() {
262+
return std::get<T>(impl);
263+
}
264+
181265
private:
182-
std::variant<HomogeneousScalarStack, SimpleCopy, AnyValueStack> impl;
266+
std::variant<HomogeneousScalarStack, SimpleCopy, SSARegister, AnyValueStack,
267+
AnyVariableStack, AnyVectorSubscriptStack>
268+
impl;
183269
};
184270
} // namespace fir::factory
185271
#endif // FORTRAN_OPTIMIZER_BUILDER_TEMPORARYSTORAGE_H

flang/lib/Optimizer/Builder/Runtime/TemporaryStack.cpp

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,3 +56,52 @@ void fir::runtime::genDestroyValueStack(mlir::Location loc,
5656
auto args = fir::runtime::createArguments(builder, loc, funcType, opaquePtr);
5757
builder.create<fir::CallOp>(loc, func, args);
5858
}
59+
60+
mlir::Value fir::runtime::genCreateDescriptorStack(mlir::Location loc,
61+
fir::FirOpBuilder &builder) {
62+
mlir::func::FuncOp func =
63+
fir::runtime::getRuntimeFunc<mkRTKey(CreateDescriptorStack)>(loc,
64+
builder);
65+
mlir::FunctionType funcType = func.getFunctionType();
66+
mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
67+
mlir::Value sourceLine =
68+
fir::factory::locationToLineNo(builder, loc, funcType.getInput(1));
69+
auto args = fir::runtime::createArguments(builder, loc, funcType, sourceFile,
70+
sourceLine);
71+
return builder.create<fir::CallOp>(loc, func, args).getResult(0);
72+
}
73+
74+
void fir::runtime::genPushDescriptor(mlir::Location loc,
75+
fir::FirOpBuilder &builder,
76+
mlir::Value opaquePtr,
77+
mlir::Value boxDescriptor) {
78+
mlir::func::FuncOp func =
79+
fir::runtime::getRuntimeFunc<mkRTKey(PushDescriptor)>(loc, builder);
80+
mlir::FunctionType funcType = func.getFunctionType();
81+
auto args = fir::runtime::createArguments(builder, loc, funcType, opaquePtr,
82+
boxDescriptor);
83+
builder.create<fir::CallOp>(loc, func, args);
84+
}
85+
86+
void fir::runtime::genDescriptorAt(mlir::Location loc,
87+
fir::FirOpBuilder &builder,
88+
mlir::Value opaquePtr, mlir::Value i,
89+
mlir::Value retDescriptorBox) {
90+
mlir::func::FuncOp func =
91+
fir::runtime::getRuntimeFunc<mkRTKey(DescriptorAt)>(loc, builder);
92+
mlir::FunctionType funcType = func.getFunctionType();
93+
auto args = fir::runtime::createArguments(builder, loc, funcType, opaquePtr,
94+
i, retDescriptorBox);
95+
builder.create<fir::CallOp>(loc, func, args);
96+
}
97+
98+
void fir::runtime::genDestroyDescriptorStack(mlir::Location loc,
99+
fir::FirOpBuilder &builder,
100+
mlir::Value opaquePtr) {
101+
mlir::func::FuncOp func =
102+
fir::runtime::getRuntimeFunc<mkRTKey(DestroyDescriptorStack)>(loc,
103+
builder);
104+
mlir::FunctionType funcType = func.getFunctionType();
105+
auto args = fir::runtime::createArguments(builder, loc, funcType, opaquePtr);
106+
builder.create<fir::CallOp>(loc, func, args);
107+
}

flang/lib/Optimizer/Builder/TemporaryStorage.cpp

Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -231,3 +231,127 @@ void fir::factory::AnyValueStack::destroy(mlir::Location loc,
231231
fir::FirOpBuilder &builder) {
232232
fir::runtime::genDestroyValueStack(loc, builder, opaquePtr);
233233
}
234+
235+
//===----------------------------------------------------------------------===//
236+
// fir::factory::AnyVariableStack implementation.
237+
//===----------------------------------------------------------------------===//
238+
239+
fir::factory::AnyVariableStack::AnyVariableStack(mlir::Location loc,
240+
fir::FirOpBuilder &builder,
241+
mlir::Type variableStaticType)
242+
: variableStaticType{variableStaticType},
243+
counter{loc, builder,
244+
builder.createIntegerConstant(loc, builder.getI64Type(), 0),
245+
/*stackThroughLoops=*/true} {
246+
opaquePtr = fir::runtime::genCreateDescriptorStack(loc, builder);
247+
mlir::Type storageType =
248+
hlfir::getFortranElementOrSequenceType(variableStaticType);
249+
mlir::Type ptrType = fir::PointerType::get(storageType);
250+
mlir::Type boxType;
251+
if (hlfir::isPolymorphicType(variableStaticType))
252+
boxType = fir::ClassType::get(ptrType);
253+
else
254+
boxType = fir::BoxType::get(ptrType);
255+
retValueBox = builder.createTemporary(loc, boxType);
256+
}
257+
258+
void fir::factory::AnyVariableStack::pushValue(mlir::Location loc,
259+
fir::FirOpBuilder &builder,
260+
mlir::Value variable) {
261+
hlfir::Entity entity{variable};
262+
mlir::Type storageElementType =
263+
hlfir::getFortranElementType(retValueBox.getType());
264+
auto [box, maybeCleanUp] =
265+
hlfir::convertToBox(loc, builder, entity, storageElementType);
266+
fir::runtime::genPushDescriptor(loc, builder, opaquePtr, fir::getBase(box));
267+
if (maybeCleanUp)
268+
(*maybeCleanUp)();
269+
}
270+
271+
void fir::factory::AnyVariableStack::resetFetchPosition(
272+
mlir::Location loc, fir::FirOpBuilder &builder) {
273+
counter.reset(loc, builder);
274+
}
275+
276+
mlir::Value fir::factory::AnyVariableStack::fetch(mlir::Location loc,
277+
fir::FirOpBuilder &builder) {
278+
mlir::Value indexValue = counter.getAndIncrementIndex(loc, builder);
279+
fir::runtime::genDescriptorAt(loc, builder, opaquePtr, indexValue,
280+
retValueBox);
281+
hlfir::Entity retBox{builder.create<fir::LoadOp>(loc, retValueBox)};
282+
// The runtime always tracks variable as address, but the form of the variable
283+
// that was saved may be different (raw address, fir.boxchar), ensure
284+
// the returned variable has the same form of the one that was saved.
285+
if (mlir::isa<fir::BaseBoxType>(variableStaticType))
286+
return builder.createConvert(loc, variableStaticType, retBox);
287+
if (mlir::isa<fir::BoxCharType>(variableStaticType))
288+
return hlfir::genVariableBoxChar(loc, builder, retBox);
289+
mlir::Value rawAddr = genVariableRawAddress(loc, builder, retBox);
290+
return builder.createConvert(loc, variableStaticType, rawAddr);
291+
}
292+
293+
void fir::factory::AnyVariableStack::destroy(mlir::Location loc,
294+
fir::FirOpBuilder &builder) {
295+
fir::runtime::genDestroyDescriptorStack(loc, builder, opaquePtr);
296+
}
297+
298+
//===----------------------------------------------------------------------===//
299+
// fir::factory::AnyVectorSubscriptStack implementation.
300+
//===----------------------------------------------------------------------===//
301+
302+
fir::factory::AnyVectorSubscriptStack::AnyVectorSubscriptStack(
303+
mlir::Location loc, fir::FirOpBuilder &builder,
304+
mlir::Type variableStaticType, bool shapeCanBeSavedAsRegister, int rank)
305+
: AnyVariableStack{loc, builder, variableStaticType} {
306+
if (shapeCanBeSavedAsRegister) {
307+
shapeTemp =
308+
std::unique_ptr<TemporaryStorage>(new TemporaryStorage{SSARegister{}});
309+
return;
310+
}
311+
// The shape will be tracked as the dimension inside a descriptor because
312+
// that is the easiest from a lowering point of view, and this is an
313+
// edge case situation that will probably not very well be exercised.
314+
mlir::Type type =
315+
fir::BoxType::get(builder.getVarLenSeqTy(builder.getI32Type(), rank));
316+
boxType = type;
317+
shapeTemp = std::unique_ptr<TemporaryStorage>(
318+
new TemporaryStorage{AnyVariableStack{loc, builder, type}});
319+
}
320+
321+
void fir::factory::AnyVectorSubscriptStack::pushShape(
322+
mlir::Location loc, fir::FirOpBuilder &builder, mlir::Value shape) {
323+
if (boxType) {
324+
// The shape is saved as a dimensions inside a descriptors.
325+
mlir::Type refType = fir::ReferenceType::get(
326+
hlfir::getFortranElementOrSequenceType(*boxType));
327+
mlir::Value null = builder.createNullConstant(loc, refType);
328+
mlir::Value descriptor =
329+
builder.create<fir::EmboxOp>(loc, *boxType, null, shape);
330+
shapeTemp->pushValue(loc, builder, descriptor);
331+
return;
332+
}
333+
// Otherwise, simply keep track of the fir.shape itself, it is invariant.
334+
shapeTemp->cast<SSARegister>().pushValue(loc, builder, shape);
335+
}
336+
337+
void fir::factory::AnyVectorSubscriptStack::resetFetchPosition(
338+
mlir::Location loc, fir::FirOpBuilder &builder) {
339+
static_cast<AnyVariableStack *>(this)->resetFetchPosition(loc, builder);
340+
shapeTemp->resetFetchPosition(loc, builder);
341+
}
342+
343+
mlir::Value
344+
fir::factory::AnyVectorSubscriptStack::fetchShape(mlir::Location loc,
345+
fir::FirOpBuilder &builder) {
346+
if (boxType) {
347+
hlfir::Entity descriptor{shapeTemp->fetch(loc, builder)};
348+
return hlfir::genShape(loc, builder, descriptor);
349+
}
350+
return shapeTemp->cast<SSARegister>().fetch(loc, builder);
351+
}
352+
353+
void fir::factory::AnyVectorSubscriptStack::destroy(
354+
mlir::Location loc, fir::FirOpBuilder &builder) {
355+
static_cast<AnyVariableStack *>(this)->destroy(loc, builder);
356+
shapeTemp->destroy(loc, builder);
357+
}

0 commit comments

Comments
 (0)