Skip to content

Commit eb68b4f

Browse files
Add SymbolTable::rename.
1 parent 5485313 commit eb68b4f

File tree

2 files changed

+48
-0
lines changed

2 files changed

+48
-0
lines changed

mlir/include/mlir/IR/SymbolTable.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,14 @@ class SymbolTable {
5555
/// after insertion as attribute.
5656
StringAttr insert(Operation *symbol, Block::iterator insertPt = {});
5757

58+
/// Renames the given op or the op refered to by the given name to the given
59+
/// new name and updates the symbol table and all usages of the symbol
60+
/// accordingly. Fails if the updating of the usages fails.
61+
LogicalResult rename(StringAttr from, StringAttr to);
62+
LogicalResult rename(Operation *op, StringAttr to);
63+
LogicalResult rename(StringAttr from, StringRef to);
64+
LogicalResult rename(Operation *op, StringRef to);
65+
5866
/// Return the name of the attribute used for symbol names.
5967
static StringRef getSymbolAttrName() { return "sym_name"; }
6068

mlir/lib/IR/SymbolTable.cpp

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,46 @@ StringAttr SymbolTable::insert(Operation *symbol, Block::iterator insertPt) {
218218
return getSymbolName(symbol);
219219
}
220220

221+
LogicalResult SymbolTable::rename(StringAttr from, StringAttr to) {
222+
Operation *op = lookup(from);
223+
return rename(op, to);
224+
}
225+
226+
LogicalResult SymbolTable::rename(Operation *op, StringAttr to) {
227+
StringAttr from = getNameIfSymbol(op);
228+
229+
assert(from && "expected valid 'name' attribute");
230+
assert(op->getParentOp() == symbolTableOp &&
231+
"expected this operation to be inside of the operation with this "
232+
"SymbolTable");
233+
assert(lookup(from) == op && "current name does not resolve to op");
234+
assert(lookup(to) == nullptr && "new name already exists");
235+
236+
if (failed(SymbolTable::replaceAllSymbolUses(op, to, getOp())))
237+
return failure();
238+
239+
// Remove op with old name, change name, add with new name. The order is
240+
// important here due to how `remove` and `insert` rely on the op name.
241+
remove(op);
242+
setSymbolName(op, to);
243+
insert(op);
244+
245+
assert(lookup(to) == op && "new name does not resolve to renamed op");
246+
assert(lookup(from) == nullptr && "old name still exists");
247+
248+
return success();
249+
}
250+
251+
LogicalResult SymbolTable::rename(StringAttr from, StringRef to) {
252+
auto toAttr = StringAttr::get(getOp()->getContext(), to);
253+
return rename(from, toAttr);
254+
}
255+
256+
LogicalResult SymbolTable::rename(Operation *op, StringRef to) {
257+
auto toAttr = StringAttr::get(getOp()->getContext(), to);
258+
return rename(op, toAttr);
259+
}
260+
221261
/// Returns the name of the given symbol operation.
222262
StringAttr SymbolTable::getSymbolName(Operation *symbol) {
223263
StringAttr name = getNameIfSymbol(symbol);

0 commit comments

Comments
 (0)