Skip to content

Commit dfa90b5

Browse files
committed
[mlir] Walk nested tables in symbol dce
The previous positioning was effectively that a symbol is dead if it cannot be addressed from top level. I think that is too strong a requirement: one can have operations that one cannot delete/DCE that refers to symbols which one could delete. This resulted in symbol-dce deleting symbols that are still referenced and the resulting IR being invalid. This instead all the symbols of top level operations of non-symbol table ops additionally, as those are either dead and DCE would have handled, or alive and we cannot just delete symbols referenced internally. E.g., this treats non-symbol table regioned ops more conservatively.
1 parent cf9546b commit dfa90b5

File tree

2 files changed

+62
-5
lines changed

2 files changed

+62
-5
lines changed

mlir/lib/Transforms/SymbolDCE.cpp

Lines changed: 43 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#include "mlir/Transforms/Passes.h"
1515

1616
#include "mlir/IR/SymbolTable.h"
17+
#include "llvm/Support/Debug.h"
1718

1819
namespace mlir {
1920
#define GEN_PASS_DEF_SYMBOLDCE
@@ -22,6 +23,8 @@ namespace mlir {
2223

2324
using namespace mlir;
2425

26+
#define DEBUG_TYPE "symbol-dce"
27+
2528
namespace {
2629
struct SymbolDCE : public impl::SymbolDCEBase<SymbolDCE> {
2730
void runOnOperation() override;
@@ -84,6 +87,8 @@ LogicalResult SymbolDCE::computeLiveness(Operation *symbolTableOp,
8487
SymbolTableCollection &symbolTable,
8588
bool symbolTableIsHidden,
8689
DenseSet<Operation *> &liveSymbols) {
90+
LLVM_DEBUG(llvm::dbgs() << "computeLiveness: " << symbolTableOp->getName()
91+
<< "\n");
8792
// A worklist of live operations to propagate uses from.
8893
SmallVector<Operation *, 16> worklist;
8994

@@ -105,36 +110,69 @@ LogicalResult SymbolDCE::computeLiveness(Operation *symbolTableOp,
105110
}
106111

107112
// Process the set of symbols that were known to be live, adding new symbols
108-
// that are referenced within.
113+
// that are referenced within. For operations that are not symbol tables, it
114+
// considers the liveness with respect to the op itself rather than scope of
115+
// nested symbol tables by enqueuing all the top level operations for
116+
// consideration.
109117
while (!worklist.empty()) {
110118
Operation *op = worklist.pop_back_val();
119+
LLVM_DEBUG(llvm::dbgs() << "processing: " << op->getName() << "\n");
111120

112121
// If this is a symbol table, recursively compute its liveness.
113122
if (op->hasTrait<OpTrait::SymbolTable>()) {
114123
// The internal symbol table is hidden if the parent is, if its not a
115124
// symbol, or if it is a private symbol.
116125
SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(op);
117126
bool symIsHidden = symbolTableIsHidden || !symbol || symbol.isPrivate();
127+
LLVM_DEBUG(llvm::dbgs() << "\tsymbol table: " << op->getName()
128+
<< " is hidden: " << symIsHidden << "\n");
118129
if (failed(computeLiveness(op, symbolTable, symIsHidden, liveSymbols)))
119130
return failure();
131+
} else {
132+
LLVM_DEBUG(llvm::dbgs()
133+
<< "\tnon-symbol table: " << op->getName() << "\n");
134+
// If the op is not a symbol table, then, unless op itself is dead which
135+
// would be handled by DCE, we need to check all the regions and blocks
136+
// within the op to find the uses (e.g., consider visibility within op as
137+
// if top level rather than relying on pure symbol table visibility). This
138+
// is more conservative than SymbolTable::walkSymbolTables in the case
139+
// where there is again SymbolTable information to take advantage of.
140+
for (auto &region : op->getRegions())
141+
for (auto &block : region.getBlocks())
142+
for (Operation &op : block)
143+
worklist.push_back(&op);
120144
}
121145

146+
// Get the first parent symbol table op. Note: due to enqueueing of
147+
// top-level ops, we may not have a symbol table parent here, but if we do
148+
// not, then we also don't have a symbol.
149+
Operation *parentOp = op->getParentOp();
150+
if (!parentOp->hasTrait<OpTrait::SymbolTable>())
151+
continue;
152+
122153
// Collect the uses held by this operation.
123154
std::optional<SymbolTable::UseRange> uses = SymbolTable::getSymbolUses(op);
124155
if (!uses) {
125156
return op->emitError()
126-
<< "operation contains potentially unknown symbol table, "
127-
"meaning that we can't reliable compute symbol uses";
157+
<< "operation contains potentially unknown symbol table, meaning "
158+
<< "that we can't reliable compute symbol uses";
128159
}
129160

130161
SmallVector<Operation *, 4> resolvedSymbols;
162+
LLVM_DEBUG(llvm::dbgs() << "uses of " << op->getName() << "\n");
131163
for (const SymbolTable::SymbolUse &use : *uses) {
164+
LLVM_DEBUG(llvm::dbgs() << "\tuse: " << use.getUser() << "\n");
132165
// Lookup the symbols referenced by this use.
133166
resolvedSymbols.clear();
134-
if (failed(symbolTable.lookupSymbolIn(
135-
op->getParentOp(), use.getSymbolRef(), resolvedSymbols)))
167+
if (failed(symbolTable.lookupSymbolIn(parentOp, use.getSymbolRef(),
168+
resolvedSymbols)))
136169
// Ignore references to unknown symbols.
137170
continue;
171+
LLVM_DEBUG({
172+
llvm::dbgs() << "\t\tresolved symbols: ";
173+
llvm::interleaveComma(resolvedSymbols, llvm::dbgs());
174+
llvm::dbgs() << "\n";
175+
});
138176

139177
// Mark each of the resolved symbols as live.
140178
for (Operation *resolvedSymbol : resolvedSymbols)

mlir/test/Transforms/test-symbol-dce.mlir

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,3 +98,22 @@ module {
9898
// CHECK: "live.user"() {uses = [@unknown_symbol]} : () -> ()
9999
"live.user"() {uses = [@unknown_symbol]} : () -> ()
100100
}
101+
102+
// -----
103+
104+
// Check that we don't DCE nested symbols if they are used even if nested inside
105+
// an unnamed region.
106+
// CHECK-LABEL: module attributes {test.nested_unnamed_region}
107+
module attributes {test.nested_unnamed_region} {
108+
"test.one_region_op"() ({
109+
"test.symbol_scope"() ({
110+
// CHECK: func @nfunction
111+
func.func @nfunction() {
112+
return
113+
}
114+
func.call @nfunction() : () -> ()
115+
"test.finish"() : () -> ()
116+
}) : () -> ()
117+
"test.finish"() : () -> ()
118+
}) : () -> ()
119+
}

0 commit comments

Comments
 (0)