Skip to content

Commit 9aeb7f0

Browse files
clementvaljeanPerierschweitzpgi
committed
[flang] Lower IO input with vector subscripts
This patch adds lowering for IO input with vector subscripts. It defines a VectorSubscriptBox class that allow representing and working with a lowered Designator containing vector subscripts while ensuring all the subscripts expression are only lowered once. This patch is part of the upstreaming effort from fir-dev branch. Reviewed By: PeteSteinfeld Differential Revision: https://reviews.llvm.org/D121806 Co-authored-by: Jean Perier <[email protected]> Co-authored-by: Eric Schweitz <[email protected]>
1 parent 6a23d27 commit 9aeb7f0

File tree

15 files changed

+1495
-75
lines changed

15 files changed

+1495
-75
lines changed

flang/include/flang/Lower/AbstractConverter.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,11 @@ class AbstractConverter {
159159
/// Generate the type from a Variable
160160
virtual mlir::Type genType(const pft::Variable &) = 0;
161161

162+
/// Register a runtime derived type information object symbol to ensure its
163+
/// object will be generated as a global.
164+
virtual void registerRuntimeTypeInfo(mlir::Location loc,
165+
SymbolRef typeInfoSym) = 0;
166+
162167
//===--------------------------------------------------------------------===//
163168
// Locations
164169
//===--------------------------------------------------------------------===//

flang/include/flang/Lower/Support/Utils.h

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,4 +57,25 @@ static Fortran::lower::SomeExpr toEvExpr(const A &x) {
5757
return Fortran::evaluate::AsGenericExpr(Fortran::common::Clone(x));
5858
}
5959

60+
template <Fortran::common::TypeCategory FROM>
61+
static Fortran::lower::SomeExpr ignoreEvConvert(
62+
const Fortran::evaluate::Convert<
63+
Fortran::evaluate::Type<Fortran::common::TypeCategory::Integer, 8>,
64+
FROM> &x) {
65+
return toEvExpr(x.left());
66+
}
67+
template <typename A>
68+
static Fortran::lower::SomeExpr ignoreEvConvert(const A &x) {
69+
return toEvExpr(x);
70+
}
71+
72+
/// A vector subscript expression may be wrapped with a cast to INTEGER*8.
73+
/// Get rid of it here so the vector can be loaded. Add it back when
74+
/// generating the elemental evaluation (inside the loop nest).
75+
inline Fortran::lower::SomeExpr
76+
ignoreEvConvert(const Fortran::evaluate::Expr<Fortran::evaluate::Type<
77+
Fortran::common::TypeCategory::Integer, 8>> &x) {
78+
return std::visit([](const auto &v) { return ignoreEvConvert(v); }, x.u);
79+
}
80+
6081
#endif // FORTRAN_LOWER_SUPPORT_UTILS_H
Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,154 @@
1+
//===-- VectorSubscripts.h -- vector subscripts tools -----------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
///
9+
/// \file
10+
/// \brief Defines a compiler internal representation for lowered designators
11+
/// containing vector subscripts. This representation allows working on such
12+
/// designators in custom ways while ensuring the designator subscripts are
13+
/// only evaluated once. It is mainly intended for cases that do not fit in
14+
/// the array expression lowering framework like input IO in presence of
15+
/// vector subscripts.
16+
///
17+
//===----------------------------------------------------------------------===//
18+
19+
#ifndef FORTRAN_LOWER_VECTORSUBSCRIPTS_H
20+
#define FORTRAN_LOWER_VECTORSUBSCRIPTS_H
21+
22+
#include "flang/Optimizer/Builder/BoxValue.h"
23+
24+
namespace fir {
25+
class FirOpBuilder;
26+
}
27+
28+
namespace Fortran {
29+
30+
namespace evaluate {
31+
template <typename>
32+
class Expr;
33+
struct SomeType;
34+
} // namespace evaluate
35+
36+
namespace lower {
37+
38+
class AbstractConverter;
39+
class StatementContext;
40+
41+
/// VectorSubscriptBox is a lowered representation for any Designator<T> that
42+
/// contain at least one vector subscript.
43+
///
44+
/// A designator `x%a(i,j)%b(1:foo():1, vector, k)%c%d(m)%e1
45+
/// Is lowered into:
46+
/// - an ExtendedValue for ranked base (x%a(i,j)%b)
47+
/// - mlir:Values and ExtendedValues for the triplet, vector subscript and
48+
/// scalar subscripts of the ranked array reference (1:foo():1, vector, k)
49+
/// - a list of fir.field_index and scalar integers mlir::Value for the
50+
/// component
51+
/// path at the right of the ranked array ref (%c%d(m)%e).
52+
///
53+
/// This representation allows later creating loops over the designator elements
54+
/// and fir.array_coor to get the element addresses without re-evaluating any
55+
/// sub-expressions.
56+
class VectorSubscriptBox {
57+
public:
58+
/// Type of the callbacks that can be passed to work with the element
59+
/// addresses.
60+
using ElementalGenerator = std::function<void(const fir::ExtendedValue &)>;
61+
using ElementalGeneratorWithBoolReturn =
62+
std::function<mlir::Value(const fir::ExtendedValue &)>;
63+
struct LoweredVectorSubscript {
64+
LoweredVectorSubscript(fir::ExtendedValue &&vector, mlir::Value size)
65+
: vector{std::move(vector)}, size{size} {}
66+
fir::ExtendedValue vector;
67+
// Vector size, guaranteed to be of indexType.
68+
mlir::Value size;
69+
};
70+
struct LoweredTriplet {
71+
// Triplets value, guaranteed to be of indexType.
72+
mlir::Value lb;
73+
mlir::Value ub;
74+
mlir::Value stride;
75+
};
76+
using LoweredSubscript =
77+
std::variant<mlir::Value, LoweredTriplet, LoweredVectorSubscript>;
78+
using MaybeSubstring = llvm::SmallVector<mlir::Value, 2>;
79+
VectorSubscriptBox(
80+
fir::ExtendedValue &&loweredBase,
81+
llvm::SmallVector<LoweredSubscript, 16> &&loweredSubscripts,
82+
llvm::SmallVector<mlir::Value> &&componentPath,
83+
MaybeSubstring substringBounds, mlir::Type elementType)
84+
: loweredBase{std::move(loweredBase)}, loweredSubscripts{std::move(
85+
loweredSubscripts)},
86+
componentPath{std::move(componentPath)},
87+
substringBounds{substringBounds}, elementType{elementType} {};
88+
89+
/// Loop over the elements described by the VectorSubscriptBox, and call
90+
/// \p elementalGenerator inside the loops with the element addresses.
91+
void loopOverElements(fir::FirOpBuilder &builder, mlir::Location loc,
92+
const ElementalGenerator &elementalGenerator);
93+
94+
/// Loop over the elements described by the VectorSubscriptBox while a
95+
/// condition is true, and call \p elementalGenerator inside the loops with
96+
/// the element addresses. The initial condition value is \p initialCondition,
97+
/// and then it is the result of \p elementalGenerator. The value of the
98+
/// condition after the loops is returned.
99+
mlir::Value loopOverElementsWhile(
100+
fir::FirOpBuilder &builder, mlir::Location loc,
101+
const ElementalGeneratorWithBoolReturn &elementalGenerator,
102+
mlir::Value initialCondition);
103+
104+
/// Return the type of the elements of the array section.
105+
mlir::Type getElementType() { return elementType; }
106+
107+
private:
108+
/// Common implementation for DoLoop and IterWhile loop creations.
109+
template <typename LoopType, typename Generator>
110+
mlir::Value loopOverElementsBase(fir::FirOpBuilder &builder,
111+
mlir::Location loc,
112+
const Generator &elementalGenerator,
113+
mlir::Value initialCondition);
114+
/// Create sliceOp for the designator.
115+
mlir::Value createSlice(fir::FirOpBuilder &builder, mlir::Location loc);
116+
117+
/// Create ExtendedValue the element inside the loop.
118+
fir::ExtendedValue getElementAt(fir::FirOpBuilder &builder,
119+
mlir::Location loc, mlir::Value shape,
120+
mlir::Value slice,
121+
mlir::ValueRange inductionVariables);
122+
123+
/// Generate the [lb, ub, step] to loop over the section (in loop order, not
124+
/// Fortran dimension order).
125+
llvm::SmallVector<std::tuple<mlir::Value, mlir::Value, mlir::Value>>
126+
genLoopBounds(fir::FirOpBuilder &builder, mlir::Location loc);
127+
128+
/// Lowered base of the ranked array ref.
129+
fir::ExtendedValue loweredBase;
130+
/// Subscripts values of the rank arrayRef part.
131+
llvm::SmallVector<LoweredSubscript, 16> loweredSubscripts;
132+
/// Scalar subscripts and components at the right of the ranked
133+
/// array ref part of any.
134+
llvm::SmallVector<mlir::Value> componentPath;
135+
/// List of substring bounds if this is a substring (only the lower bound if
136+
/// the upper is implicit).
137+
MaybeSubstring substringBounds;
138+
/// Type of the elements described by this array section.
139+
mlir::Type elementType;
140+
};
141+
142+
/// Lower \p expr, that must be an designator containing vector subscripts, to a
143+
/// VectorSubscriptBox representation. This causes evaluation of all the
144+
/// subscripts. Any required clean-ups from subscript expression are added to \p
145+
/// stmtCtx.
146+
VectorSubscriptBox genVectorSubscriptBox(
147+
mlir::Location loc, Fortran::lower::AbstractConverter &converter,
148+
Fortran::lower::StatementContext &stmtCtx,
149+
const Fortran::evaluate::Expr<Fortran::evaluate::SomeType> &expr);
150+
151+
} // namespace lower
152+
} // namespace Fortran
153+
154+
#endif // FORTRAN_LOWER_VECTORSUBSCRIPTS_H

flang/include/flang/Optimizer/Transforms/Passes.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ std::unique_ptr<mlir::Pass> createMemoryAllocationPass();
3838
std::unique_ptr<mlir::Pass>
3939
createMemoryAllocationPass(bool dynOnHeap, std::size_t maxStackSize);
4040
std::unique_ptr<mlir::Pass> createAnnotateConstantOperandsPass();
41+
std::unique_ptr<mlir::Pass> createSimplifyRegionLitePass();
4142

4243
// declarative passes
4344
#define GEN_PASS_REGISTRATION

flang/include/flang/Optimizer/Transforms/Passes.td

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,4 +188,12 @@ def MemoryAllocationOpt : Pass<"memory-allocation-opt", "mlir::FuncOp"> {
188188
let constructor = "::fir::createMemoryAllocationPass()";
189189
}
190190

191+
def SimplifyRegionLite : Pass<"simplify-region-lite", "mlir::ModuleOp"> {
192+
let summary = "Region simplification";
193+
let description = [{
194+
Run region DCE and erase unreachable blocks in regions.
195+
}];
196+
let constructor = "::fir::createSimplifyRegionLitePass()";
197+
}
198+
191199
#endif // FLANG_OPTIMIZER_TRANSFORMS_PASSES

flang/include/flang/Tools/CLOptions.inc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,7 @@ inline void createDefaultFIROptimizerPassPipeline(mlir::PassManager &pm) {
143143
fir::addAVC(pm);
144144
pm.addNestedPass<mlir::FuncOp>(fir::createCharacterConversionPass());
145145
pm.addPass(mlir::createCanonicalizerPass(config));
146+
pm.addPass(fir::createSimplifyRegionLitePass());
146147
fir::addMemoryAllocationOpt(pm);
147148

148149
// The default inliner pass adds the canonicalizer pass with the default
@@ -157,6 +158,7 @@ inline void createDefaultFIROptimizerPassPipeline(mlir::PassManager &pm) {
157158
pm.addPass(mlir::createConvertSCFToCFPass());
158159

159160
pm.addPass(mlir::createCanonicalizerPass(config));
161+
pm.addPass(fir::createSimplifyRegionLitePass());
160162
}
161163

162164
#if !defined(FLANG_EXCLUDE_CODEGEN)

flang/lib/Lower/Bridge.cpp

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,68 @@ static llvm::cl::opt<bool> dumpBeforeFir(
4949
"fdebug-dump-pre-fir", llvm::cl::init(false),
5050
llvm::cl::desc("dump the Pre-FIR tree prior to FIR generation"));
5151

52+
namespace {
53+
/// Helper class to generate the runtime type info global data. This data
54+
/// is required to describe the derived type to the runtime so that it can
55+
/// operate over it. It must be ensured this data will be generated for every
56+
/// derived type lowered in the current translated unit. However, this data
57+
/// cannot be generated before FuncOp have been created for functions since the
58+
/// initializers may take their address (e.g for type bound procedures). This
59+
/// class allows registering all the required runtime type info while it is not
60+
/// possible to create globals, and to generate this data after function
61+
/// lowering.
62+
class RuntimeTypeInfoConverter {
63+
/// Store the location and symbols of derived type info to be generated.
64+
/// The location of the derived type instantiation is also stored because
65+
/// runtime type descriptor symbol are compiler generated and cannot be mapped
66+
/// to user code on their own.
67+
struct TypeInfoSymbol {
68+
Fortran::semantics::SymbolRef symbol;
69+
mlir::Location loc;
70+
};
71+
72+
public:
73+
void registerTypeInfoSymbol(Fortran::lower::AbstractConverter &converter,
74+
mlir::Location loc,
75+
Fortran::semantics::SymbolRef typeInfoSym) {
76+
if (seen.contains(typeInfoSym))
77+
return;
78+
seen.insert(typeInfoSym);
79+
if (!skipRegistration) {
80+
registeredTypeInfoSymbols.emplace_back(TypeInfoSymbol{typeInfoSym, loc});
81+
return;
82+
}
83+
// Once the registration is closed, symbols cannot be added to the
84+
// registeredTypeInfoSymbols list because it may be iterated over.
85+
// However, after registration is closed, it is safe to directly generate
86+
// the globals because all FuncOps whose addresses may be required by the
87+
// initializers have been generated.
88+
Fortran::lower::createRuntimeTypeInfoGlobal(converter, loc,
89+
typeInfoSym.get());
90+
}
91+
92+
void createTypeInfoGlobals(Fortran::lower::AbstractConverter &converter) {
93+
skipRegistration = true;
94+
for (const TypeInfoSymbol &info : registeredTypeInfoSymbols)
95+
Fortran::lower::createRuntimeTypeInfoGlobal(converter, info.loc,
96+
info.symbol.get());
97+
registeredTypeInfoSymbols.clear();
98+
}
99+
100+
private:
101+
/// Store the runtime type descriptors that will be required for the
102+
/// derived type that have been converted to FIR derived types.
103+
llvm::SmallVector<TypeInfoSymbol> registeredTypeInfoSymbols;
104+
/// Create derived type runtime info global immediately without storing the
105+
/// symbol in registeredTypeInfoSymbols.
106+
bool skipRegistration = false;
107+
/// Track symbols symbols processed during and after the registration
108+
/// to avoid infinite loops between type conversions and global variable
109+
/// creation.
110+
llvm::SmallSetVector<Fortran::semantics::SymbolRef, 64> seen;
111+
};
112+
} // namespace
113+
52114
//===----------------------------------------------------------------------===//
53115
// FirConverter
54116
//===----------------------------------------------------------------------===//
@@ -101,6 +163,12 @@ class FirConverter : public Fortran::lower::AbstractConverter {
101163
},
102164
u);
103165
}
166+
167+
/// Once all the code has been translated, create runtime type info
168+
/// global data structure for the derived types that have been
169+
/// processed.
170+
createGlobalOutsideOfFunctionLowering(
171+
[&]() { runtimeTypeInfoConverter.createTypeInfoGlobals(*this); });
104172
}
105173

106174
/// Declare a function.
@@ -689,6 +757,12 @@ class FirConverter : public Fortran::lower::AbstractConverter {
689757
hostAssocTuple = val;
690758
}
691759

760+
void registerRuntimeTypeInfo(
761+
mlir::Location loc,
762+
Fortran::lower::SymbolRef typeInfoSym) override final {
763+
runtimeTypeInfoConverter.registerTypeInfoSymbol(*this, loc, typeInfoSym);
764+
}
765+
692766
private:
693767
FirConverter() = delete;
694768
FirConverter(const FirConverter &) = delete;
@@ -2319,6 +2393,7 @@ class FirConverter : public Fortran::lower::AbstractConverter {
23192393
Fortran::lower::pft::Evaluation *evalPtr = nullptr;
23202394
Fortran::lower::SymMap localSymbols;
23212395
Fortran::parser::CharBlock currentPosition;
2396+
RuntimeTypeInfoConverter runtimeTypeInfoConverter;
23222397

23232398
/// Tuple of host assoicated variables.
23242399
mlir::Value hostAssocTuple;

flang/lib/Lower/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ add_flang_library(FortranLower
2121
PFTBuilder.cpp
2222
Runtime.cpp
2323
SymbolMap.cpp
24+
VectorSubscripts.cpp
2425

2526
DEPENDS
2627
FIRDialect

0 commit comments

Comments
 (0)