@@ -258,13 +258,16 @@ Operation *SymbolTable::lookupSymbolIn(Operation *symbolTableOp,
258
258
return resolvedSymbols.back ();
259
259
}
260
260
261
- LogicalResult
262
- SymbolTable::lookupSymbolIn (Operation *symbolTableOp, SymbolRefAttr symbol,
263
- SmallVectorImpl<Operation *> &symbols) {
261
+ // / Internal implementation of `lookupSymbolIn` that allows for specialized
262
+ // / implementations of the lookup function.
263
+ static LogicalResult lookupSymbolInImpl (
264
+ Operation *symbolTableOp, SymbolRefAttr symbol,
265
+ SmallVectorImpl<Operation *> &symbols,
266
+ function_ref<Operation *(Operation *, StringRef)> lookupSymbolFn) {
264
267
assert (symbolTableOp->hasTrait <OpTrait::SymbolTable>());
265
268
266
269
// Lookup the root reference for this symbol.
267
- symbolTableOp = lookupSymbolIn (symbolTableOp, symbol.getRootReference ());
270
+ symbolTableOp = lookupSymbolFn (symbolTableOp, symbol.getRootReference ());
268
271
if (!symbolTableOp)
269
272
return failure ();
270
273
symbols.push_back (symbolTableOp);
@@ -281,15 +284,24 @@ SymbolTable::lookupSymbolIn(Operation *symbolTableOp, SymbolRefAttr symbol,
281
284
// Otherwise, lookup each of the nested non-leaf references and ensure that
282
285
// each corresponds to a valid symbol table.
283
286
for (FlatSymbolRefAttr ref : nestedRefs.drop_back ()) {
284
- symbolTableOp = lookupSymbolIn (symbolTableOp, ref.getValue ());
287
+ symbolTableOp = lookupSymbolFn (symbolTableOp, ref.getValue ());
285
288
if (!symbolTableOp || !symbolTableOp->hasTrait <OpTrait::SymbolTable>())
286
289
return failure ();
287
290
symbols.push_back (symbolTableOp);
288
291
}
289
- symbols.push_back (lookupSymbolIn (symbolTableOp, symbol.getLeafReference ()));
292
+ symbols.push_back (lookupSymbolFn (symbolTableOp, symbol.getLeafReference ()));
290
293
return success (symbols.back ());
291
294
}
292
295
296
+ LogicalResult
297
+ SymbolTable::lookupSymbolIn (Operation *symbolTableOp, SymbolRefAttr symbol,
298
+ SmallVectorImpl<Operation *> &symbols) {
299
+ auto lookupFn = [](Operation *symbolTableOp, StringRef symbol) {
300
+ return lookupSymbolIn (symbolTableOp, symbol);
301
+ };
302
+ return lookupSymbolInImpl (symbolTableOp, symbol, symbols, lookupFn);
303
+ }
304
+
293
305
// / Returns the operation registered with the given symbol name within the
294
306
// / closes parent operation with the 'OpTrait::SymbolTable' trait. Returns
295
307
// / nullptr if no valid symbol was found.
@@ -887,6 +899,42 @@ LogicalResult SymbolTable::replaceAllSymbolUses(Operation *oldSymbol,
887
899
return replaceAllSymbolUsesImpl (oldSymbol, newSymbol, from);
888
900
}
889
901
902
+ // ===----------------------------------------------------------------------===//
903
+ // SymbolTableCollection
904
+ // ===----------------------------------------------------------------------===//
905
+
906
+ Operation *SymbolTableCollection::lookupSymbolIn (Operation *symbolTableOp,
907
+ StringRef symbol) {
908
+ return getSymbolTable (symbolTableOp).lookup (symbol);
909
+ }
910
+ Operation *SymbolTableCollection::lookupSymbolIn (Operation *symbolTableOp,
911
+ SymbolRefAttr name) {
912
+ SmallVector<Operation *, 4 > symbols;
913
+ if (failed (lookupSymbolIn (symbolTableOp, name, symbols)))
914
+ return nullptr ;
915
+ return symbols.back ();
916
+ }
917
+ // / A variant of 'lookupSymbolIn' that returns all of the symbols referenced by
918
+ // / a given SymbolRefAttr. Returns failure if any of the nested references could
919
+ // / not be resolved.
920
+ LogicalResult
921
+ SymbolTableCollection::lookupSymbolIn (Operation *symbolTableOp,
922
+ SymbolRefAttr name,
923
+ SmallVectorImpl<Operation *> &symbols) {
924
+ auto lookupFn = [this ](Operation *symbolTableOp, StringRef symbol) {
925
+ return lookupSymbolIn (symbolTableOp, symbol);
926
+ };
927
+ return lookupSymbolInImpl (symbolTableOp, name, symbols, lookupFn);
928
+ }
929
+
930
+ // / Lookup, or create, a symbol table for an operation.
931
+ SymbolTable &SymbolTableCollection::getSymbolTable (Operation *op) {
932
+ auto it = symbolTables.try_emplace (op, nullptr );
933
+ if (it.second )
934
+ it.first ->second = std::make_unique<SymbolTable>(op);
935
+ return *it.first ->second ;
936
+ }
937
+
890
938
// ===----------------------------------------------------------------------===//
891
939
// Symbol Interfaces
892
940
// ===----------------------------------------------------------------------===//
0 commit comments