Skip to content

Commit 4ad5c20

Browse files
Use new functions in SymbolTable.
1 parent 02e1ef5 commit 4ad5c20

File tree

1 file changed

+21
-56
lines changed

1 file changed

+21
-56
lines changed

mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp

Lines changed: 21 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -304,57 +304,6 @@ static void performOptionalDebugActions(
304304
transform->removeAttr(kTransformDialectTagAttrName);
305305
}
306306

307-
/// Rename `op` to avoid a collision with `otherOp`. `symbolTable` and
308-
/// `otherSymbolTable` are the symbol tables of the two ops, respectively.
309-
/// `uniqueId` is used to generate a unique name in the context of the caller.
310-
LogicalResult renameToUnique(SymbolOpInterface op, SymbolOpInterface otherOp,
311-
SymbolTable &symbolTable,
312-
SymbolTable &otherSymbolTable, int &uniqueId) {
313-
assert(symbolTable.lookup(op.getNameAttr()) == op &&
314-
"symbol table does not contain op");
315-
assert(otherSymbolTable.lookup(otherOp.getNameAttr()) == otherOp &&
316-
"other symbol table does not contain other op");
317-
318-
// Determine new name that is unique in both symbol tables.
319-
StringAttr oldName = op.getNameAttr();
320-
StringAttr newName;
321-
{
322-
MLIRContext *context = op->getContext();
323-
SmallString<64> prefix = oldName.getValue();
324-
prefix.push_back('_');
325-
while (true) {
326-
newName = StringAttr::get(context, prefix + Twine(uniqueId++));
327-
if (!symbolTable.lookup(newName) && !otherSymbolTable.lookup(newName)) {
328-
break;
329-
}
330-
}
331-
}
332-
333-
// Apply renaming.
334-
LLVM_DEBUG(llvm::dbgs() << ", renaming to @" << newName.getValue() << "\n");
335-
Operation *symbolTableOp = SymbolTable::getNearestSymbolTable(op);
336-
if (failed(SymbolTable::replaceAllSymbolUses(op, newName, symbolTableOp))) {
337-
InFlightDiagnostic diag =
338-
emitError(op->getLoc(),
339-
Twine("failed to rename symbol to @") + newName.getValue());
340-
diag.attachNote(otherOp->getLoc())
341-
<< "attempted renaming due to collision with this op";
342-
return diag;
343-
}
344-
345-
// Change the symbol in the op itself and update the symbol table.
346-
symbolTable.remove(op);
347-
SymbolTable::setSymbolName(op, newName);
348-
symbolTable.insert(op);
349-
350-
assert(symbolTable.lookup(newName) == op &&
351-
"symbol table does not resolve to renamed op");
352-
assert(symbolTable.lookup(oldName) == nullptr &&
353-
"symbol table still resolves old name");
354-
355-
return success();
356-
}
357-
358307
/// Return whether `func1` can be merged into `func2`.
359308
bool canMergeInto(FunctionOpInterface func1, FunctionOpInterface func2) {
360309
return func1.isExternal() && (func2.isPublic() || func2.isExternal());
@@ -429,8 +378,6 @@ static LogicalResult mergeSymbolsInto(Operation *target,
429378
SymbolTable targetSymbolTable(target);
430379
SymbolTable otherSymbolTable(*other);
431380

432-
int uniqueId = 0;
433-
434381
// Step 1:
435382
//
436383
// Rename private symbols in both ops in order to resolve conflicts that can
@@ -471,16 +418,34 @@ static LogicalResult mergeSymbolsInto(Operation *target,
471418
LLVM_DEBUG(llvm::dbgs() << " and both ops are function definitions");
472419
}
473420

474-
// Collision can be resolved if one of the ops is private.
421+
// Collision can be resolved by renaming if one of the ops is private.
422+
auto renameToUnique =
423+
[&](SymbolOpInterface op, SymbolOpInterface otherOp,
424+
SymbolTable &symbolTable,
425+
SymbolTable &otherSymbolTable) -> LogicalResult {
426+
LLVM_DEBUG(llvm::dbgs() << ", renaming\n");
427+
FailureOr<StringAttr> maybeNewName =
428+
symbolTable.renameToUnique(op, {&otherSymbolTable});
429+
if (failed(maybeNewName)) {
430+
InFlightDiagnostic diag = op->emitError("failed to rename symbol");
431+
diag.attachNote(otherOp->getLoc())
432+
<< "attempted renaming due to collision with this op";
433+
return diag;
434+
}
435+
LLVM_DEBUG(DBGS() << " renamed to @" << maybeNewName->getValue()
436+
<< "\n");
437+
return success();
438+
};
439+
475440
if (symbolOp.isPrivate()) {
476441
if (failed(renameToUnique(symbolOp, collidingOp, *symbolTable,
477-
*otherSymbolTable, uniqueId)))
442+
*otherSymbolTable)))
478443
return failure();
479444
continue;
480445
}
481446
if (collidingOp.isPrivate()) {
482447
if (failed(renameToUnique(collidingOp, symbolOp, *otherSymbolTable,
483-
*symbolTable, uniqueId)))
448+
*symbolTable)))
484449
return failure();
485450
continue;
486451
}

0 commit comments

Comments
 (0)