Skip to content

Commit 605c706

Browse files
author
git apple-llvm automerger
committed
Merge commit 'ed07412888e2' from llvm.org/main into next
2 parents 72905d6 + ed07412 commit 605c706

File tree

7 files changed

+329
-21
lines changed

7 files changed

+329
-21
lines changed
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
//===-- LLVMInsertChainFolder.h -- insertvalue chain folder ----*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// Helper to fold LLVM dialect llvm.insertvalue chain representing constants
10+
// into an Attribute representation.
11+
// This sits in Flang because it is incomplete and tailored for flang needs.
12+
//
13+
//===----------------------------------------------------------------------===//
14+
15+
#include "llvm/Support/LogicalResult.h"
16+
17+
namespace mlir {
18+
class Attribute;
19+
class OpBuilder;
20+
class Value;
21+
} // namespace mlir
22+
23+
namespace fir {
24+
25+
/// Attempt to fold an llvm.insertvalue chain into an attribute representation
26+
/// suitable as an llvm.constant operand. The result is a failure if the value
27+
/// is neither an llvm constant nor an llvm.insertvalue result, if the chain is
28+
/// not a constant, or if it cannot be represented as an Attribute. Only chains
29+
/// producing an LLVM struct value are folded. The operations are not deleted,
30+
/// but some llvm.insertvalue value operands may be folded with the builder on
/// the way.
31+
llvm::FailureOr<mlir::Attribute>
32+
tryFoldingLLVMInsertChain(mlir::Value insertChainResult,
33+
mlir::OpBuilder &builder);
34+
} // namespace fir

flang/include/flang/Optimizer/Dialect/FIROps.td

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2129,6 +2129,11 @@ def fir_InsertOnRangeOp : fir_OneResultOp<"insert_on_range", [NoMemoryEffect]> {
21292129
$seq `,` $val custom<CustomRangeSubscript>($coor) attr-dict `:` functional-type(operands, results)
21302130
}];
21312131

2132+
let extraClassDeclaration = [{
2133+
/// Is this insert_on_range inserting on all the values of the result type?
2134+
bool isFullRange();
2135+
}];
2136+
21322137
let hasVerifier = 1;
21332138
}
21342139

flang/lib/Optimizer/CodeGen/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ add_flang_library(FIRCodeGen
33
CodeGen.cpp
44
CodeGenOpenMP.cpp
55
FIROpPatterns.cpp
6+
LLVMInsertChainFolder.cpp
67
LowerRepackArrays.cpp
78
PreCGRewrite.cpp
89
TBAABuilder.cpp

flang/lib/Optimizer/CodeGen/CodeGen.cpp

Lines changed: 31 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
#include "flang/Optimizer/CodeGen/CodeGenOpenMP.h"
1616
#include "flang/Optimizer/CodeGen/FIROpPatterns.h"
17+
#include "flang/Optimizer/CodeGen/LLVMInsertChainFolder.h"
1718
#include "flang/Optimizer/CodeGen/TypeConverter.h"
1819
#include "flang/Optimizer/Dialect/FIRAttr.h"
1920
#include "flang/Optimizer/Dialect/FIRCG/CGOps.h"
@@ -2412,15 +2413,39 @@ struct InsertOnRangeOpConversion
24122413
doRewrite(fir::InsertOnRangeOp range, mlir::Type ty, OpAdaptor adaptor,
24132414
mlir::ConversionPatternRewriter &rewriter) const override {
24142415

2415-
llvm::SmallVector<std::int64_t> dims;
2416-
auto type = adaptor.getOperands()[0].getType();
2416+
auto arrayType = adaptor.getSeq().getType();
24172417

24182418
// Iteratively extract the array dimensions from the type.
2419+
llvm::SmallVector<std::int64_t> dims;
2420+
mlir::Type type = arrayType;
24192421
while (auto t = mlir::dyn_cast<mlir::LLVM::LLVMArrayType>(type)) {
24202422
dims.push_back(t.getNumElements());
24212423
type = t.getElementType();
24222424
}
24232425

2426+
// Avoid generating long insert chains that are very slow to fold back
2427+
// (which is required in globals when later generating LLVM IR). Attempt to
2428+
// fold the inserted element value to an attribute and build an ArrayAttr
2429+
// for the resulting array.
2430+
if (range.isFullRange()) {
2431+
llvm::FailureOr<mlir::Attribute> cst =
2432+
fir::tryFoldingLLVMInsertChain(adaptor.getVal(), rewriter);
2433+
if (llvm::succeeded(cst)) {
2434+
mlir::Attribute dimVal = *cst;
2435+
for (auto dim : llvm::reverse(dims)) {
2436+
// Use std::vector in case the number of elements is big.
2437+
std::vector<mlir::Attribute> elements(dim, dimVal);
2438+
dimVal = mlir::ArrayAttr::get(range.getContext(), elements);
2439+
}
2440+
// Replace insert chain with constant.
2441+
rewriter.replaceOpWithNewOp<mlir::LLVM::ConstantOp>(range, arrayType,
2442+
dimVal);
2443+
return mlir::success();
2444+
}
2445+
}
2446+
2447+
// The inserted value cannot be folded to an attribute, turn the
2448+
// insert_on_range into an llvm.insertvalue chain.
24242449
llvm::SmallVector<std::int64_t> lBounds;
24252450
llvm::SmallVector<std::int64_t> uBounds;
24262451

@@ -2434,8 +2459,8 @@ struct InsertOnRangeOpConversion
24342459

24352460
auto &subscripts = lBounds;
24362461
auto loc = range.getLoc();
2437-
mlir::Value lastOp = adaptor.getOperands()[0];
2438-
mlir::Value insertVal = adaptor.getOperands()[1];
2462+
mlir::Value lastOp = adaptor.getSeq();
2463+
mlir::Value insertVal = adaptor.getVal();
24392464

24402465
while (subscripts != uBounds) {
24412466
lastOp = rewriter.create<mlir::LLVM::InsertValueOp>(
@@ -3131,7 +3156,7 @@ struct GlobalOpConversion : public fir::FIROpConversion<fir::GlobalOp> {
31313156
// initialization is on the full range.
31323157
auto insertOnRangeOps = gr.front().getOps<fir::InsertOnRangeOp>();
31333158
for (auto insertOp : insertOnRangeOps) {
3134-
if (isFullRange(insertOp.getCoor(), insertOp.getType())) {
3159+
if (insertOp.isFullRange()) {
31353160
auto seqTyAttr = convertType(insertOp.getType());
31363161
auto *op = insertOp.getVal().getDefiningOp();
31373162
auto constant = mlir::dyn_cast<mlir::arith::ConstantOp>(op);
@@ -3161,22 +3186,7 @@ struct GlobalOpConversion : public fir::FIROpConversion<fir::GlobalOp> {
31613186
return mlir::success();
31623187
}
31633188

3164-
bool isFullRange(mlir::DenseIntElementsAttr indexes,
3165-
fir::SequenceType seqTy) const {
3166-
auto extents = seqTy.getShape();
3167-
if (indexes.size() / 2 != static_cast<int64_t>(extents.size()))
3168-
return false;
3169-
auto cur_index = indexes.value_begin<int64_t>();
3170-
for (unsigned i = 0; i < indexes.size(); i += 2) {
3171-
if (*(cur_index++) != 0)
3172-
return false;
3173-
if (*(cur_index++) != extents[i / 2] - 1)
3174-
return false;
3175-
}
3176-
return true;
3177-
}
3178-
3179-
// TODO: String comparaison should be avoided. Replace linkName with an
3189+
// TODO: String comparisons should be avoided. Replace linkName with an
31803190
// enumeration.
31813191
mlir::LLVM::Linkage
31823192
convertLinkage(std::optional<llvm::StringRef> optLinkage) const {
Lines changed: 210 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,210 @@
1+
//===-- LLVMInsertChainFolder.cpp -----------------------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "flang/Optimizer/CodeGen/LLVMInsertChainFolder.h"
10+
#include "mlir/Dialect/LLVMIR/LLVMAttrs.h"
11+
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
12+
#include "mlir/IR/Builders.h"
13+
#include "llvm/Support/Debug.h"
14+
15+
#define DEBUG_TYPE "flang-insert-folder"
16+
17+
#include <deque>
18+
19+
namespace {
20+
// Helper class to construct the attribute elements of an aggregate value being
21+
// folded without creating a full mlir::Attribute representation for each step
22+
// of the insert value chain, which would both be expensive in terms of
23+
// compilation time and memory (since the intermediate Attribute would survive,
24+
// unused, inside the mlir context).
25+
class InsertChainBackwardFolder {
26+
// Type for the current value of an element of the aggregate value being
27+
// constructed by the insert chain.
28+
// At any point of the insert chain, the value of an element is either:
29+
// - nullptr: not yet known, the insert has not yet been seen.
30+
// - an mlir::Attribute: the element is fully defined.
31+
// - a nested InsertChainBackwardFolder: the element is itself an aggregate
32+
// and its sub-elements have been partially defined (inserts with multiple
33+
// indices have been seen).
34+
35+
// The insertion folder assumes backward walk of the insert chain. Once an
36+
// element or sub-element has been defined, it is not overridden by new
37+
// insertions (last insert wins).
38+
using InFlightValue =
39+
llvm::PointerUnion<mlir::Attribute, InsertChainBackwardFolder *>;
40+
41+
public:
42+
// Create a folder for an aggregate of the given type, with one initially
// unset slot per element of the type. `folderStorage` is the container in
// which pushValue allocates the folders of nested aggregates.
InsertChainBackwardFolder(
43+
mlir::Type type, std::deque<InsertChainBackwardFolder> *folderStorage)
44+
: values(getNumElements(type), mlir::Attribute{}),
45+
folderStorage{folderStorage}, type{type} {}
46+
47+
/// Record `val` as the value inserted at position `at`. An element that
/// already holds an attribute is left untouched. Returns false when the
/// insert cannot be handled: empty position, first index beyond the number
/// of elements, no sub-element type for a nested position, or a value that
/// is neither zero nor undef inserted over an element whose sub-elements
/// have already been partially set.
48+
bool pushValue(mlir::Attribute val, llvm::ArrayRef<int64_t> at);
49+
50+
/// Build an mlir::ArrayAttr with one attribute per element. Elements that
/// were never set get `defaultFieldValue`; nested folders are finalized
/// recursively with the same default.
mlir::Attribute finalize(mlir::Attribute defaultFieldValue);
51+
52+
private:
53+
// Number of elements of an LLVM struct or array type; 0 for any other type
// (including a null type).
static int64_t getNumElements(mlir::Type type) {
54+
if (auto structTy =
55+
llvm::dyn_cast_if_present<mlir::LLVM::LLVMStructType>(type))
56+
return structTy.getBody().size();
57+
if (auto arrayTy =
58+
llvm::dyn_cast_if_present<mlir::LLVM::LLVMArrayType>(type))
59+
return arrayTy.getNumElements();
60+
return 0;
61+
}
62+
63+
// Type of the element at index `field` of an LLVM array or struct type; a
// null type if `type` is neither. For arrays the index is not used since
// all elements share the same type.
static mlir::Type getSubElementType(mlir::Type type, int64_t field) {
64+
if (auto arrayTy =
65+
llvm::dyn_cast_if_present<mlir::LLVM::LLVMArrayType>(type))
66+
return arrayTy.getElementType();
67+
if (auto structTy =
68+
llvm::dyn_cast_if_present<mlir::LLVM::LLVMStructType>(type))
69+
return structTy.getBody()[field];
70+
return nullptr;
71+
}
72+
73+
// Current element value of the aggregate value being built.
74+
llvm::SmallVector<InFlightValue> values;
75+
// std::deque is used to allocate storage for nested lists and guarantee the
76+
// stability of the InsertChainBackwardFolder* used as element value.
77+
std::deque<InsertChainBackwardFolder> *folderStorage;
78+
// Type of the aggregate value being built.
79+
mlir::Type type;
80+
};
81+
} // namespace
82+
83+
// Helper to fold the value being inserted by an llvm.insertvalue.
84+
// This may call tryFoldingLLVMInsertChain if the value is an aggregate and
85+
// was itself constructed by a different insert chain.
86+
// Returns a nullptr Attribute if the value could not be folded.
87+
static mlir::Attribute getAttrIfConstant(mlir::Value val,
88+
mlir::OpBuilder &rewriter) {
89+
// A constant operation directly carries the attribute.
if (auto cst = val.getDefiningOp<mlir::LLVM::ConstantOp>())
90+
return cst.getValue();
91+
// The inserted value is itself the result of an insert chain: fold that
// chain, and give up if it cannot be folded.
if (auto insert = val.getDefiningOp<mlir::LLVM::InsertValueOp>()) {
92+
llvm::FailureOr<mlir::Attribute> attr =
93+
fir::tryFoldingLLVMInsertChain(val, rewriter);
94+
if (succeeded(attr))
95+
return *attr;
96+
return nullptr;
97+
}
98+
// Zero and undef values map to their dedicated attributes.
if (val.getDefiningOp<mlir::LLVM::ZeroOp>())
99+
return mlir::LLVM::ZeroAttr::get(val.getContext());
100+
if (val.getDefiningOp<mlir::LLVM::UndefOp>())
101+
return mlir::LLVM::UndefAttr::get(val.getContext());
102+
// Any other defining operation: ask the builder to fold it, and use the
// folded result matching `val` if it is produced by a constant operation.
if (mlir::Operation *op = val.getDefiningOp()) {
103+
unsigned resNum = llvm::cast<mlir::OpResult>(val).getResultNumber();
104+
llvm::SmallVector<mlir::Value> results;
105+
if (mlir::succeeded(rewriter.tryFold(op, results)) &&
106+
results.size() > resNum) {
107+
if (auto cst = results[resNum].getDefiningOp<mlir::LLVM::ConstantOp>())
108+
return cst.getValue();
109+
}
110+
}
111+
// Look through a trunc whose argument folds to an integer attribute and
// rebuild the integer attribute with the type of the trunc result.
// NOTE(review): the source integer value is passed unchanged to the
// narrower type - confirm IntegerAttr::get accepts values that do not fit.
if (auto trunc = val.getDefiningOp<mlir::LLVM::TruncOp>())
112+
if (auto attr = getAttrIfConstant(trunc.getArg(), rewriter))
113+
if (auto intAttr = llvm::dyn_cast<mlir::IntegerAttr>(attr))
114+
return mlir::IntegerAttr::get(trunc.getType(), intAttr.getInt());
115+
LLVM_DEBUG(llvm::dbgs() << "cannot fold insert value operand: " << val
116+
<< "\n");
117+
return nullptr;
118+
}
119+
120+
// Build the mlir::ArrayAttr for the aggregate from the values recorded so far.
// Elements that were never set get `defaultFieldValue`, and nested folders
// are finalized recursively with the same default.
mlir::Attribute
121+
InsertChainBackwardFolder::finalize(mlir::Attribute defaultFieldValue) {
122+
llvm::SmallVector<mlir::Attribute> attrs = llvm::map_to_vector(
123+
values, [&](InFlightValue inFlight) -> mlir::Attribute {
124+
// Element never inserted in the chain.
if (!inFlight)
125+
return defaultFieldValue;
126+
// Element fully defined by an insert.
if (auto attr = llvm::dyn_cast<mlir::Attribute>(inFlight))
127+
return attr;
128+
// Element partially defined through nested inserts.
return llvm::cast<InsertChainBackwardFolder *>(inFlight)->finalize(
129+
defaultFieldValue);
130+
});
131+
return mlir::ArrayAttr::get(type.getContext(), attrs);
132+
}
133+
134+
// Record `val` as the value inserted at position `at` of the aggregate.
// Returns true if the insert was recorded or ignored because the element is
// already fully defined, and false if the insert cannot be handled.
bool InsertChainBackwardFolder::pushValue(mlir::Attribute val,
135+
llvm::ArrayRef<int64_t> at) {
136+
// Reject an empty position or a first index beyond the number of elements.
if (at.size() == 0 || at[0] >= static_cast<int64_t>(values.size()))
137+
return false;
138+
InFlightValue &inFlight = values[at[0]];
139+
if (!inFlight) {
140+
// Element not set yet and no further index: `val` defines it entirely.
if (at.size() == 1) {
141+
inFlight = val;
142+
return true;
143+
}
144+
// This is the first insert to a nested field. Create an
145+
// InsertChainBackwardFolder for the current element value.
146+
mlir::Type subType = getSubElementType(type, at[0]);
147+
if (!subType)
148+
return false;
149+
InsertChainBackwardFolder &inFlightList =
150+
folderStorage->emplace_back(subType, folderStorage);
151+
inFlight = &inFlightList;
152+
return inFlightList.pushValue(val, at.drop_front());
153+
}
154+
// Keep last inserted value if already set.
155+
if (llvm::isa<mlir::Attribute>(inFlight))
156+
return true;
157+
auto *inFlightList = llvm::cast<InsertChainBackwardFolder *>(inFlight);
158+
// The element already has some sub-elements set and this insert targets the
// whole element. This is only handled when `val` is zero or undef, in which
// case it becomes the default for the sub-elements that are still unset and
// the element is turned into a fully defined attribute.
if (at.size() == 1) {
159+
if (!llvm::isa<mlir::LLVM::ZeroAttr, mlir::LLVM::UndefAttr>(val)) {
160+
LLVM_DEBUG(llvm::dbgs()
161+
<< "insert chain sub-element partially overwritten initial "
162+
"value is not zero or undef\n");
163+
return false;
164+
}
165+
inFlight = inFlightList->finalize(val);
166+
return true;
167+
}
168+
// Nested insert into an element that already has a nested folder.
return inFlightList->pushValue(val, at.drop_front());
169+
}
170+
171+
// Fold `val` into an attribute if it is produced by a constant operation or by
// a chain of insertvalue operations building an LLVM struct from constant
// values on top of a zero or undef initial value. Returns failure otherwise.
llvm::FailureOr<mlir::Attribute>
172+
fir::tryFoldingLLVMInsertChain(mlir::Value val, mlir::OpBuilder &rewriter) {
173+
if (auto cst = val.getDefiningOp<mlir::LLVM::ConstantOp>())
174+
return cst.getValue();
175+
if (auto insert = val.getDefiningOp<mlir::LLVM::InsertValueOp>()) {
176+
LLVM_DEBUG(llvm::dbgs() << "trying to fold insert chain:" << val << "\n");
177+
// Only insert chains whose result is an LLVM struct are folded here.
if (auto structTy =
178+
llvm::dyn_cast<mlir::LLVM::LLVMStructType>(insert.getType())) {
179+
mlir::LLVM::InsertValueOp currentInsert = insert;
180+
mlir::LLVM::InsertValueOp lastInsert;
181+
std::deque<InsertChainBackwardFolder> folderStorage;
182+
InsertChainBackwardFolder inFlightList(structTy, &folderStorage);
183+
// Walk the chain backward, from the given insert through the container
// operands, recording every inserted value. Give up as soon as one
// value cannot be folded or recorded.
while (currentInsert) {
184+
mlir::Attribute attr =
185+
getAttrIfConstant(currentInsert.getValue(), rewriter);
186+
if (!attr)
187+
return llvm::failure();
188+
if (!inFlightList.pushValue(attr, currentInsert.getPosition()))
189+
return llvm::failure();
190+
lastInsert = currentInsert;
191+
currentInsert = currentInsert.getContainer()
192+
.getDefiningOp<mlir::LLVM::InsertValueOp>();
193+
}
194+
// `lastInsert` is the last insert visited by the walk; its container is
// the initial value of the chain and provides the value of the elements
// that were never inserted. Only zero and undef are supported.
mlir::Attribute defaultVal;
195+
if (lastInsert) {
196+
if (lastInsert.getContainer().getDefiningOp<mlir::LLVM::ZeroOp>())
197+
defaultVal = mlir::LLVM::ZeroAttr::get(val.getContext());
198+
else if (lastInsert.getContainer().getDefiningOp<mlir::LLVM::UndefOp>())
199+
defaultVal = mlir::LLVM::UndefAttr::get(val.getContext());
200+
}
201+
if (!defaultVal) {
202+
LLVM_DEBUG(llvm::dbgs()
203+
<< "insert chain initial value is not Zero or Undef\n");
204+
return llvm::failure();
205+
}
206+
return inFlightList.finalize(defaultVal);
207+
}
208+
}
209+
return llvm::failure();
210+
}

flang/lib/Optimizer/Dialect/FIROps.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2365,6 +2365,21 @@ llvm::LogicalResult fir::InsertOnRangeOp::verify() {
23652365
return mlir::success();
23662366
}
23672367

2368+
// Returns true if the coordinates of this operation hold, for every dimension
// of the result type shape, the pair (0, extent - 1).
bool fir::InsertOnRangeOp::isFullRange() {
2369+
auto extents = getType().getShape();
2370+
mlir::DenseIntElementsAttr indexes = getCoor();
2371+
// The coordinates are read as two values per dimension, so there must be as
// many pairs as there are extents.
// NOTE(review): presumably the number of coordinates is always even - an odd
// count could make the loop below read past the attribute; confirm that the
// verifier enforces it.
if (indexes.size() / 2 != static_cast<int64_t>(extents.size()))
2372+
return false;
2373+
auto cur_index = indexes.value_begin<int64_t>();
2374+
for (unsigned i = 0; i < indexes.size(); i += 2) {
2375+
// First value of the pair must be 0.
if (*(cur_index++) != 0)
2376+
return false;
2377+
// Second value of the pair must be the last index of the dimension.
if (*(cur_index++) != extents[i / 2] - 1)
2378+
return false;
2379+
}
2380+
return true;
2381+
}
2382+
23682383
//===----------------------------------------------------------------------===//
23692384
// InsertValueOp
23702385
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)