Skip to content

[MLIR] Make SymbolTableCollection methods virtual #141760

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 3, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 21 additions & 12 deletions mlir/include/mlir/IR/SymbolTable.h
Original file line number Diff line number Diff line change
Expand Up @@ -282,10 +282,14 @@ raw_ostream &operator<<(raw_ostream &os, SymbolTable::Visibility visibility);
/// unnecessary tables.
class SymbolTableCollection {
public:
virtual ~SymbolTableCollection() = default;

/// Look up a symbol with the specified name within the specified symbol table
/// operation, returning null if no such name exists.
Operation *lookupSymbolIn(Operation *symbolTableOp, StringAttr symbol);
Operation *lookupSymbolIn(Operation *symbolTableOp, SymbolRefAttr name);
virtual Operation *lookupSymbolIn(Operation *symbolTableOp,
StringAttr symbol);
virtual Operation *lookupSymbolIn(Operation *symbolTableOp,
SymbolRefAttr name);
template <typename T, typename NameT>
T lookupSymbolIn(Operation *symbolTableOp, NameT &&name) {
return dyn_cast_or_null<T>(
Expand All @@ -295,15 +299,18 @@ class SymbolTableCollection {
/// by a given SymbolRefAttr when resolved within the provided symbol table
/// operation. Returns failure if any of the nested references could not be
/// resolved.
LogicalResult lookupSymbolIn(Operation *symbolTableOp, SymbolRefAttr name,
SmallVectorImpl<Operation *> &symbols);
virtual LogicalResult lookupSymbolIn(Operation *symbolTableOp,
SymbolRefAttr name,
SmallVectorImpl<Operation *> &symbols);

/// Returns the operation registered with the given symbol name within the
/// closest parent operation of, or including, 'from' with the
/// 'OpTrait::SymbolTable' trait. Returns nullptr if no valid symbol was
/// found.
Operation *lookupNearestSymbolFrom(Operation *from, StringAttr symbol);
Operation *lookupNearestSymbolFrom(Operation *from, SymbolRefAttr symbol);
virtual Operation *lookupNearestSymbolFrom(Operation *from,
StringAttr symbol);
virtual Operation *lookupNearestSymbolFrom(Operation *from,
SymbolRefAttr symbol);
template <typename T>
T lookupNearestSymbolFrom(Operation *from, StringAttr symbol) {
return dyn_cast_or_null<T>(lookupNearestSymbolFrom(from, symbol));
Expand All @@ -314,14 +321,14 @@ class SymbolTableCollection {
}

/// Lookup, or create, a symbol table for an operation.
SymbolTable &getSymbolTable(Operation *op);
virtual SymbolTable &getSymbolTable(Operation *op);

/// Invalidate the cached symbol table for an operation.
/// This is important when doing IR modifications that erase and also create
/// operations having the 'OpTrait::SymbolTable' trait. If a symbol table of
/// an erased operation is not invalidated, a new operation sharing the same
/// address would be associated with outdated, and wrong, information.
void invalidateSymbolTable(Operation *op);
virtual void invalidateSymbolTable(Operation *op);

private:
friend class LockedSymbolTableCollection;
Expand All @@ -348,13 +355,15 @@ class LockedSymbolTableCollection : public SymbolTableCollection {

/// Look up a symbol with the specified name within the specified symbol table
/// operation, returning null if no such name exists.
Operation *lookupSymbolIn(Operation *symbolTableOp, StringAttr symbol);
Operation *lookupSymbolIn(Operation *symbolTableOp,
StringAttr symbol) override;
/// Look up a symbol with the specified name within the specified symbol table
/// operation, returning null if no such name exists.
Operation *lookupSymbolIn(Operation *symbolTableOp, FlatSymbolRefAttr symbol);
/// Look up a potentially nested symbol within the specified symbol table
/// operation, returning null if no such symbol exists.
Operation *lookupSymbolIn(Operation *symbolTableOp, SymbolRefAttr name);
Operation *lookupSymbolIn(Operation *symbolTableOp,
SymbolRefAttr name) override;

/// Lookup a symbol of a particular kind within the specified symbol table,
/// returning null if the symbol was not found.
Expand All @@ -369,14 +378,14 @@ class LockedSymbolTableCollection : public SymbolTableCollection {
/// operation. Returns failure if any of the nested references could not be
/// resolved.
LogicalResult lookupSymbolIn(Operation *symbolTableOp, SymbolRefAttr name,
SmallVectorImpl<Operation *> &symbols);
SmallVectorImpl<Operation *> &symbols) override;

private:
/// Get the symbol table for the symbol table operation, constructing if it
/// does not exist. This function provides thread safety over `collection`
/// by locking when performing the lookup and when inserting
/// lazily-constructed symbol tables.
SymbolTable &getSymbolTable(Operation *symbolTableOp);
SymbolTable &getSymbolTable(Operation *symbolTableOp) override;

/// The symbol tables to manage.
SymbolTableCollection &collection;
Expand Down
Loading