Skip to content

Commit 490fc19

Browse files
author
git apple-llvm automerger
committed
Merge commit '7bc7d0ac7ae2' from llvm.org/master into apple/main
2 parents c0f5473 + 7bc7d0a commit 490fc19

File tree

5 files changed

+101
-12
lines changed

5 files changed

+101
-12
lines changed

mlir/include/mlir/IR/OpDefinition.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1536,7 +1536,7 @@ class OpInterface
15361536
/// Inherit the base class constructor.
15371537
using InterfaceBase::InterfaceBase;
15381538

1539-
private:
1539+
protected:
15401540
/// Returns the impl interface instance for the given operation.
15411541
static typename InterfaceBase::Concept *getInterfaceFor(Operation *op) {
15421542
// Access the raw interface from the abstract operation.

mlir/include/mlir/IR/SymbolInterfaces.td

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -145,8 +145,11 @@ def Symbol : OpInterface<"SymbolOpInterface"> {
145145
let extraClassDeclaration = [{
146146
/// Custom classof that handles the case where the symbol is optional.
147147
static bool classof(Operation *op) {
148-
return Base::classof(op)
149-
&& op->getAttr(::mlir::SymbolTable::getSymbolAttrName());
148+
auto *concept = getInterfaceFor(op);
149+
if (!concept)
150+
return false;
151+
return !concept->isOptionalSymbol(op) ||
152+
op->getAttr(::mlir::SymbolTable::getSymbolAttrName());
150153
}
151154
}];
152155

mlir/include/mlir/IR/SymbolTable.h

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,40 @@ class SymbolTable {
210210
unsigned uniquingCounter = 0;
211211
};
212212

213+
//===----------------------------------------------------------------------===//
214+
// SymbolTableCollection
215+
//===----------------------------------------------------------------------===//
216+
217+
/// This class represents a collection of `SymbolTable`s. This simplifies
218+
/// certain algorithms that run recursively on nested symbol tables. Symbol
219+
/// tables are constructed lazily to reduce the upfront cost of constructing
220+
/// unnecessary tables.
221+
class SymbolTableCollection {
222+
public:
223+
/// Look up a symbol with the specified name within the specified symbol table
224+
/// operation, returning null if no such name exists.
225+
Operation *lookupSymbolIn(Operation *symbolTableOp, StringRef symbol);
226+
Operation *lookupSymbolIn(Operation *symbolTableOp, SymbolRefAttr name);
227+
template <typename T, typename NameT>
228+
T lookupSymbolIn(Operation *symbolTableOp, NameT &&name) const {
229+
return dyn_cast_or_null<T>(
230+
lookupSymbolIn(symbolTableOp, std::forward<NameT>(name)));
231+
}
232+
/// A variant of 'lookupSymbolIn' that returns all of the symbols referenced
233+
/// by a given SymbolRefAttr when resolved within the provided symbol table
234+
/// operation. Returns failure if any of the nested references could not be
235+
/// resolved.
236+
LogicalResult lookupSymbolIn(Operation *symbolTableOp, SymbolRefAttr name,
237+
SmallVectorImpl<Operation *> &symbols);
238+
239+
/// Lookup, or create, a symbol table for an operation.
240+
SymbolTable &getSymbolTable(Operation *op);
241+
242+
private:
243+
/// The constructed symbol tables nested within this table.
244+
DenseMap<Operation *, std::unique_ptr<SymbolTable>> symbolTables;
245+
};
246+
213247
//===----------------------------------------------------------------------===//
214248
// SymbolTable Trait Types
215249
//===----------------------------------------------------------------------===//

mlir/lib/IR/SymbolTable.cpp

Lines changed: 54 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -258,13 +258,16 @@ Operation *SymbolTable::lookupSymbolIn(Operation *symbolTableOp,
258258
return resolvedSymbols.back();
259259
}
260260

261-
LogicalResult
262-
SymbolTable::lookupSymbolIn(Operation *symbolTableOp, SymbolRefAttr symbol,
263-
SmallVectorImpl<Operation *> &symbols) {
261+
/// Internal implementation of `lookupSymbolIn` that allows for specialized
262+
/// implementations of the lookup function.
263+
static LogicalResult lookupSymbolInImpl(
264+
Operation *symbolTableOp, SymbolRefAttr symbol,
265+
SmallVectorImpl<Operation *> &symbols,
266+
function_ref<Operation *(Operation *, StringRef)> lookupSymbolFn) {
264267
assert(symbolTableOp->hasTrait<OpTrait::SymbolTable>());
265268

266269
// Lookup the root reference for this symbol.
267-
symbolTableOp = lookupSymbolIn(symbolTableOp, symbol.getRootReference());
270+
symbolTableOp = lookupSymbolFn(symbolTableOp, symbol.getRootReference());
268271
if (!symbolTableOp)
269272
return failure();
270273
symbols.push_back(symbolTableOp);
@@ -281,15 +284,24 @@ SymbolTable::lookupSymbolIn(Operation *symbolTableOp, SymbolRefAttr symbol,
281284
// Otherwise, lookup each of the nested non-leaf references and ensure that
282285
// each corresponds to a valid symbol table.
283286
for (FlatSymbolRefAttr ref : nestedRefs.drop_back()) {
284-
symbolTableOp = lookupSymbolIn(symbolTableOp, ref.getValue());
287+
symbolTableOp = lookupSymbolFn(symbolTableOp, ref.getValue());
285288
if (!symbolTableOp || !symbolTableOp->hasTrait<OpTrait::SymbolTable>())
286289
return failure();
287290
symbols.push_back(symbolTableOp);
288291
}
289-
symbols.push_back(lookupSymbolIn(symbolTableOp, symbol.getLeafReference()));
292+
symbols.push_back(lookupSymbolFn(symbolTableOp, symbol.getLeafReference()));
290293
return success(symbols.back());
291294
}
292295

296+
LogicalResult
297+
SymbolTable::lookupSymbolIn(Operation *symbolTableOp, SymbolRefAttr symbol,
298+
SmallVectorImpl<Operation *> &symbols) {
299+
auto lookupFn = [](Operation *symbolTableOp, StringRef symbol) {
300+
return lookupSymbolIn(symbolTableOp, symbol);
301+
};
302+
return lookupSymbolInImpl(symbolTableOp, symbol, symbols, lookupFn);
303+
}
304+
293305
/// Returns the operation registered with the given symbol name within the
294306
/// closes parent operation with the 'OpTrait::SymbolTable' trait. Returns
295307
/// nullptr if no valid symbol was found.
@@ -887,6 +899,42 @@ LogicalResult SymbolTable::replaceAllSymbolUses(Operation *oldSymbol,
887899
return replaceAllSymbolUsesImpl(oldSymbol, newSymbol, from);
888900
}
889901

902+
//===----------------------------------------------------------------------===//
903+
// SymbolTableCollection
904+
//===----------------------------------------------------------------------===//
905+
906+
Operation *SymbolTableCollection::lookupSymbolIn(Operation *symbolTableOp,
907+
StringRef symbol) {
908+
return getSymbolTable(symbolTableOp).lookup(symbol);
909+
}
910+
Operation *SymbolTableCollection::lookupSymbolIn(Operation *symbolTableOp,
911+
SymbolRefAttr name) {
912+
SmallVector<Operation *, 4> symbols;
913+
if (failed(lookupSymbolIn(symbolTableOp, name, symbols)))
914+
return nullptr;
915+
return symbols.back();
916+
}
917+
/// A variant of 'lookupSymbolIn' that returns all of the symbols referenced by
918+
/// a given SymbolRefAttr. Returns failure if any of the nested references could
919+
/// not be resolved.
920+
LogicalResult
921+
SymbolTableCollection::lookupSymbolIn(Operation *symbolTableOp,
922+
SymbolRefAttr name,
923+
SmallVectorImpl<Operation *> &symbols) {
924+
auto lookupFn = [this](Operation *symbolTableOp, StringRef symbol) {
925+
return lookupSymbolIn(symbolTableOp, symbol);
926+
};
927+
return lookupSymbolInImpl(symbolTableOp, name, symbols, lookupFn);
928+
}
929+
930+
/// Lookup, or create, a symbol table for an operation.
931+
SymbolTable &SymbolTableCollection::getSymbolTable(Operation *op) {
932+
auto it = symbolTables.try_emplace(op, nullptr);
933+
if (it.second)
934+
it.first->second = std::make_unique<SymbolTable>(op);
935+
return *it.first->second;
936+
}
937+
890938
//===----------------------------------------------------------------------===//
891939
// Symbol Interfaces
892940
//===----------------------------------------------------------------------===//

mlir/lib/Transforms/SymbolDCE.cpp

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ struct SymbolDCE : public SymbolDCEBase<SymbolDCE> {
2424
/// `symbolTableIsHidden` is true if this symbol table is known to be
2525
/// unaccessible from operations in its parent regions.
2626
LogicalResult computeLiveness(Operation *symbolTableOp,
27+
SymbolTableCollection &symbolTable,
2728
bool symbolTableIsHidden,
2829
DenseSet<Operation *> &liveSymbols);
2930
};
@@ -49,7 +50,9 @@ void SymbolDCE::runOnOperation() {
4950

5051
// Compute the set of live symbols within the symbol table.
5152
DenseSet<Operation *> liveSymbols;
52-
if (failed(computeLiveness(symbolTableOp, symbolTableIsHidden, liveSymbols)))
53+
SymbolTableCollection symbolTable;
54+
if (failed(computeLiveness(symbolTableOp, symbolTable, symbolTableIsHidden,
55+
liveSymbols)))
5356
return signalPassFailure();
5457

5558
// After computing the liveness, delete all of the symbols that were found to
@@ -71,6 +74,7 @@ void SymbolDCE::runOnOperation() {
7174
/// `symbolTableIsHidden` is true if this symbol table is known to be
7275
/// unaccessible from operations in its parent regions.
7376
LogicalResult SymbolDCE::computeLiveness(Operation *symbolTableOp,
77+
SymbolTableCollection &symbolTable,
7478
bool symbolTableIsHidden,
7579
DenseSet<Operation *> &liveSymbols) {
7680
// A worklist of live operations to propagate uses from.
@@ -104,7 +108,7 @@ LogicalResult SymbolDCE::computeLiveness(Operation *symbolTableOp,
104108
// symbol, or if it is a private symbol.
105109
SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(op);
106110
bool symIsHidden = symbolTableIsHidden || !symbol || symbol.isPrivate();
107-
if (failed(computeLiveness(op, symIsHidden, liveSymbols)))
111+
if (failed(computeLiveness(op, symbolTable, symIsHidden, liveSymbols)))
108112
return failure();
109113
}
110114

@@ -120,7 +124,7 @@ LogicalResult SymbolDCE::computeLiveness(Operation *symbolTableOp,
120124
for (const SymbolTable::SymbolUse &use : *uses) {
121125
// Lookup the symbols referenced by this use.
122126
resolvedSymbols.clear();
123-
if (failed(SymbolTable::lookupSymbolIn(
127+
if (failed(symbolTable.lookupSymbolIn(
124128
op->getParentOp(), use.getSymbolRef(), resolvedSymbols))) {
125129
return use.getUser()->emitError()
126130
<< "unable to resolve reference to symbol "

0 commit comments

Comments
 (0)