Skip to content

Commit f98a447

Browse files
jpienaarjoker-eph
authored andcommitted
[mlir] Walk nested non-symbol table ops in symbol dce (llvm#143353)
The previous code was effectively that a symbol is dead if was not nested in sequence of SymbolTables. But one can have operations that one cannot delete/DCE that refers to symbols which one could delete which resulted in symbol-dce deleting symbols that are still referenced and the resulting IR being invalid. This changes it so that all operations inside non SymbolTable op are considered to find nested SymbolTable ops. --------- Co-authored-by: Mehdi Amini <[email protected]>
1 parent 540fa0b commit f98a447

File tree

2 files changed

+79
-5
lines changed

2 files changed

+79
-5
lines changed

mlir/lib/Transforms/SymbolDCE.cpp

Lines changed: 44 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#include "mlir/Transforms/Passes.h"
1515

1616
#include "mlir/IR/SymbolTable.h"
17+
#include "llvm/Support/Debug.h"
1718

1819
namespace mlir {
1920
#define GEN_PASS_DEF_SYMBOLDCE
@@ -22,6 +23,8 @@ namespace mlir {
2223

2324
using namespace mlir;
2425

26+
#define DEBUG_TYPE "symbol-dce"
27+
2528
namespace {
2629
struct SymbolDCE : public impl::SymbolDCEBase<SymbolDCE> {
2730
void runOnOperation() override;
@@ -84,6 +87,8 @@ LogicalResult SymbolDCE::computeLiveness(Operation *symbolTableOp,
8487
SymbolTableCollection &symbolTable,
8588
bool symbolTableIsHidden,
8689
DenseSet<Operation *> &liveSymbols) {
90+
LLVM_DEBUG(llvm::dbgs() << "computeLiveness: " << symbolTableOp->getName()
91+
<< "\n");
8792
// A worklist of live operations to propagate uses from.
8893
SmallVector<Operation *, 16> worklist;
8994

@@ -105,36 +110,70 @@ LogicalResult SymbolDCE::computeLiveness(Operation *symbolTableOp,
105110
}
106111

107112
// Process the set of symbols that were known to be live, adding new symbols
108-
// that are referenced within.
113+
// that are referenced within. For operations that are not symbol tables, it
114+
// considers the liveness with respect to the op itself rather than scope of
115+
// nested symbol tables by enqueuing all the top level operations for
116+
// consideration.
109117
while (!worklist.empty()) {
110118
Operation *op = worklist.pop_back_val();
119+
LLVM_DEBUG(llvm::dbgs() << "processing: " << op->getName() << "\n");
111120

112121
// If this is a symbol table, recursively compute its liveness.
113122
if (op->hasTrait<OpTrait::SymbolTable>()) {
114123
// The internal symbol table is hidden if the parent is, if its not a
115124
// symbol, or if it is a private symbol.
116125
SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(op);
117126
bool symIsHidden = symbolTableIsHidden || !symbol || symbol.isPrivate();
127+
LLVM_DEBUG(llvm::dbgs() << "\tsymbol table: " << op->getName()
128+
<< " is hidden: " << symIsHidden << "\n");
118129
if (failed(computeLiveness(op, symbolTable, symIsHidden, liveSymbols)))
119130
return failure();
131+
} else {
132+
LLVM_DEBUG(llvm::dbgs()
133+
<< "\tnon-symbol table: " << op->getName() << "\n");
134+
// If the op is not a symbol table, then, unless op itself is dead which
135+
// would be handled by DCE, we need to check all the regions and blocks
136+
// within the op to find the uses (e.g., consider visibility within op as
137+
// if top level rather than relying on pure symbol table visibility). This
138+
// is more conservative than SymbolTable::walkSymbolTables in the case
139+
// where there is again SymbolTable information to take advantage of.
140+
for (auto &region : op->getRegions())
141+
for (auto &block : region.getBlocks())
142+
for (Operation &op : block)
143+
if (op.getNumRegions())
144+
worklist.push_back(&op);
120145
}
121146

147+
// Get the first parent symbol table op. Note: due to enqueueing of
148+
// top-level ops, we may not have a symbol table parent here, but if we do
149+
// not, then we also don't have a symbol.
150+
Operation *parentOp = op->getParentOp();
151+
if (!parentOp->hasTrait<OpTrait::SymbolTable>())
152+
continue;
153+
122154
// Collect the uses held by this operation.
123155
std::optional<SymbolTable::UseRange> uses = SymbolTable::getSymbolUses(op);
124156
if (!uses) {
125157
return op->emitError()
126-
<< "operation contains potentially unknown symbol table, "
127-
"meaning that we can't reliable compute symbol uses";
158+
<< "operation contains potentially unknown symbol table, meaning "
159+
<< "that we can't reliable compute symbol uses";
128160
}
129161

130162
SmallVector<Operation *, 4> resolvedSymbols;
163+
LLVM_DEBUG(llvm::dbgs() << "uses of " << op->getName() << "\n");
131164
for (const SymbolTable::SymbolUse &use : *uses) {
165+
LLVM_DEBUG(llvm::dbgs() << "\tuse: " << use.getUser() << "\n");
132166
// Lookup the symbols referenced by this use.
133167
resolvedSymbols.clear();
134-
if (failed(symbolTable.lookupSymbolIn(
135-
op->getParentOp(), use.getSymbolRef(), resolvedSymbols)))
168+
if (failed(symbolTable.lookupSymbolIn(parentOp, use.getSymbolRef(),
169+
resolvedSymbols)))
136170
// Ignore references to unknown symbols.
137171
continue;
172+
LLVM_DEBUG({
173+
llvm::dbgs() << "\t\tresolved symbols: ";
174+
llvm::interleaveComma(resolvedSymbols, llvm::dbgs());
175+
llvm::dbgs() << "\n";
176+
});
138177

139178
// Mark each of the resolved symbols as live.
140179
for (Operation *resolvedSymbol : resolvedSymbols)

mlir/test/Transforms/test-symbol-dce.mlir

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,3 +98,38 @@ module {
9898
// CHECK: "live.user"() {uses = [@unknown_symbol]} : () -> ()
9999
"live.user"() {uses = [@unknown_symbol]} : () -> ()
100100
}
101+
102+
// -----
103+
104+
// Check that we don't DCE nested symbols if they are nested inside region
105+
// without SymbolTable.
106+
107+
// CHECK-LABEL: module attributes {test.nested_nosymboltable_region}
108+
module attributes { test.nested_nosymboltable_region } {
109+
"test.one_region_op"() ({
110+
"test.symbol_scope"() ({
111+
// CHECK: func nested @nfunction
112+
func.func nested @nfunction() {
113+
return
114+
}
115+
func.call @nfunction() : () -> ()
116+
"test.finish"() : () -> ()
117+
}) : () -> ()
118+
"test.finish"() : () -> ()
119+
}) : () -> ()
120+
}
121+
122+
// -----
123+
124+
// CHECK-LABEL: module attributes {test.nested_nosymboltable_region_notcalled}
125+
// CHECK-NOT: @nested
126+
// CHECK: @main
127+
module attributes { test.nested_nosymboltable_region_notcalled } {
128+
"test.one_region_op"() ({
129+
module {
130+
func.func nested @nested() { return }
131+
func.func @main() { return }
132+
}
133+
"test.finish"() : () -> ()
134+
}) : () -> ()
135+
}

0 commit comments

Comments
 (0)