Skip to content

[MLIR] Add optional cached symbol tables to LLVM conversion patterns #144032

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Jun 21, 2025

Conversation

mscuttari
Copy link
Member

This PR allows to optionally speed up the lookup of symbols by providing a SymbolTableCollection instance to the interested conversion patterns.
It is follow-up on the discussion about symbol / symbol table management carried on Discourse.

@llvmbot
Copy link
Member

llvmbot commented Jun 13, 2025

@llvm/pr-subscribers-mlir

Author: Michele Scuttari (mscuttari)

Changes

This PR allows to optionally speed up the lookup of symbols by providing a SymbolTableCollection instance to the interested conversion patterns.
It is follow-up on the discussion about symbol / symbol table management carried on Discourse.


Patch is 52.27 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/144032.diff

11 Files Affected:

  • (modified) mlir/include/mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h (+4-3)
  • (modified) mlir/include/mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h (+6-3)
  • (modified) mlir/include/mlir/Conversion/LLVMCommon/PrintCallHelper.h (+3-1)
  • (modified) mlir/include/mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h (+3-1)
  • (modified) mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h (+51-32)
  • (modified) mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp (+8-5)
  • (modified) mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp (+44-28)
  • (modified) mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp (+19-4)
  • (modified) mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp (+102-38)
  • (modified) mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp (+24-12)
  • (modified) mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp (+89-49)
diff --git a/mlir/include/mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h b/mlir/include/mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h
index 88f18022da9bb..2dfb6b03bcfcd 100644
--- a/mlir/include/mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h
+++ b/mlir/include/mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h
@@ -20,6 +20,7 @@ class DialectRegistry;
 class LLVMTypeConverter;
 class RewritePatternSet;
 class Pass;
+class SymbolTableCollection;
 
 #define GEN_PASS_DECL_CONVERTCONTROLFLOWTOLLVMPASS
 #include "mlir/Conversion/Passes.h.inc"
@@ -39,9 +40,9 @@ void populateControlFlowToLLVMConversionPatterns(
 /// Populate the cf.assert to LLVM conversion pattern. If `abortOnFailure` is
 /// set to false, the program execution continues when a condition is
 /// unsatisfied.
-void populateAssertToLLVMConversionPattern(const LLVMTypeConverter &converter,
-                                           RewritePatternSet &patterns,
-                                           bool abortOnFailure = true);
+void populateAssertToLLVMConversionPattern(
+    const LLVMTypeConverter &converter, RewritePatternSet &patterns,
+    bool abortOnFailure = true, SymbolTableCollection *symbolTables = nullptr);
 
 void registerConvertControlFlowToLLVMInterface(DialectRegistry &registry);
 
diff --git a/mlir/include/mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h b/mlir/include/mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h
index b1ea2740c0605..e530b0a43b8e0 100644
--- a/mlir/include/mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h
+++ b/mlir/include/mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h
@@ -27,20 +27,23 @@ class DialectRegistry;
 class LLVMTypeConverter;
 class RewritePatternSet;
 class SymbolTable;
+class SymbolTableCollection;
 
 /// Convert input FunctionOpInterface operation to LLVMFuncOp by using the
 /// provided LLVMTypeConverter. Return failure if failed to so.
 FailureOr<LLVM::LLVMFuncOp>
 convertFuncOpToLLVMFuncOp(FunctionOpInterface funcOp,
                           ConversionPatternRewriter &rewriter,
-                          const LLVMTypeConverter &converter);
+                          const LLVMTypeConverter &converter,
+                          SymbolTableCollection *symbolTables = nullptr);
 
 /// Collect the default pattern to convert a FuncOp to the LLVM dialect. If
 /// `emitCWrappers` is set, the pattern will also produce functions
 /// that pass memref descriptors by pointer-to-structure in addition to the
 /// default unpacked form.
 void populateFuncToLLVMFuncOpConversionPattern(
-    const LLVMTypeConverter &converter, RewritePatternSet &patterns);
+    const LLVMTypeConverter &converter, RewritePatternSet &patterns,
+    SymbolTableCollection *symbolTables = nullptr);
 
 /// Collect the patterns to convert from the Func dialect to LLVM. The
 /// conversion patterns capture the LLVMTypeConverter and the LowerToLLVMOptions
@@ -57,7 +60,7 @@ void populateFuncToLLVMFuncOpConversionPattern(
 /// not an error to provide it anyway.
 void populateFuncToLLVMConversionPatterns(
     const LLVMTypeConverter &converter, RewritePatternSet &patterns,
-    const SymbolTable *symbolTable = nullptr);
+    SymbolTableCollection *symbolTables = nullptr);
 
 void registerConvertFuncToLLVMInterface(DialectRegistry &registry);
 
diff --git a/mlir/include/mlir/Conversion/LLVMCommon/PrintCallHelper.h b/mlir/include/mlir/Conversion/LLVMCommon/PrintCallHelper.h
index 33402301115b7..d7de40555bb6a 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/PrintCallHelper.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/PrintCallHelper.h
@@ -17,6 +17,7 @@ namespace mlir {
 
 class OpBuilder;
 class LLVMTypeConverter;
+class SymbolTableCollection;
 
 namespace LLVM {
 
@@ -26,7 +27,8 @@ namespace LLVM {
 LogicalResult createPrintStrCall(
     OpBuilder &builder, Location loc, ModuleOp moduleOp, StringRef symbolName,
     StringRef string, const LLVMTypeConverter &typeConverter,
-    bool addNewline = true, std::optional<StringRef> runtimeFunctionName = {});
+    bool addNewline = true, std::optional<StringRef> runtimeFunctionName = {},
+    SymbolTableCollection *symbolTables = nullptr);
 } // namespace LLVM
 
 } // namespace mlir
diff --git a/mlir/include/mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h b/mlir/include/mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h
index 996a64baf9dd5..e93d5bdce7bf2 100644
--- a/mlir/include/mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h
+++ b/mlir/include/mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h
@@ -16,6 +16,7 @@ class DialectRegistry;
 class Pass;
 class LLVMTypeConverter;
 class RewritePatternSet;
+class SymbolTableCollection;
 
 #define GEN_PASS_DECL_FINALIZEMEMREFTOLLVMCONVERSIONPASS
 #include "mlir/Conversion/Passes.h.inc"
@@ -23,7 +24,8 @@ class RewritePatternSet;
 /// Collect a set of patterns to convert memory-related operations from the
 /// MemRef dialect to the LLVM dialect.
 void populateFinalizeMemRefToLLVMConversionPatterns(
-    const LLVMTypeConverter &converter, RewritePatternSet &patterns);
+    const LLVMTypeConverter &converter, RewritePatternSet &patterns,
+    SymbolTableCollection *symbolTables = nullptr);
 
 void registerConvertMemRefToLLVMInterface(DialectRegistry &registry);
 
diff --git a/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h b/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h
index 4a7ec6f2efe64..8ad9ed18acebd 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h
@@ -24,6 +24,7 @@ class OpBuilder;
 class Operation;
 class Type;
 class ValueRange;
+class SymbolTableCollection;
 
 namespace LLVM {
 class LLVMFuncOp;
@@ -33,55 +34,73 @@ class LLVMFuncOp;
 /// implemented separately (e.g. as part of a support runtime library or as part
 /// of the libc).
 /// Failure if an unexpected version of function is found.
-FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintI64Fn(OpBuilder &b,
-                                                     Operation *moduleOp);
-FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintU64Fn(OpBuilder &b,
-                                                     Operation *moduleOp);
-FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintF16Fn(OpBuilder &b,
-                                                     Operation *moduleOp);
-FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintBF16Fn(OpBuilder &b,
-                                                      Operation *moduleOp);
-FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintF32Fn(OpBuilder &b,
-                                                     Operation *moduleOp);
-FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintF64Fn(OpBuilder &b,
-                                                     Operation *moduleOp);
+FailureOr<LLVM::LLVMFuncOp>
+lookupOrCreatePrintI64Fn(OpBuilder &b, Operation *moduleOp,
+                         SymbolTableCollection *symbolTables = nullptr);
+FailureOr<LLVM::LLVMFuncOp>
+lookupOrCreatePrintU64Fn(OpBuilder &b, Operation *moduleOp,
+                         SymbolTableCollection *symbolTables = nullptr);
+FailureOr<LLVM::LLVMFuncOp>
+lookupOrCreatePrintF16Fn(OpBuilder &b, Operation *moduleOp,
+                         SymbolTableCollection *symbolTables = nullptr);
+FailureOr<LLVM::LLVMFuncOp>
+lookupOrCreatePrintBF16Fn(OpBuilder &b, Operation *moduleOp,
+                          SymbolTableCollection *symbolTables = nullptr);
+FailureOr<LLVM::LLVMFuncOp>
+lookupOrCreatePrintF32Fn(OpBuilder &b, Operation *moduleOp,
+                         SymbolTableCollection *symbolTables = nullptr);
+FailureOr<LLVM::LLVMFuncOp>
+lookupOrCreatePrintF64Fn(OpBuilder &b, Operation *moduleOp,
+                         SymbolTableCollection *symbolTables = nullptr);
 /// Declares a function to print a C-string.
 /// If a custom runtime function is defined via `runtimeFunctionName`, it must
 /// have the signature void(char const*). The default function is `printString`.
 FailureOr<LLVM::LLVMFuncOp>
 lookupOrCreatePrintStringFn(OpBuilder &b, Operation *moduleOp,
-                            std::optional<StringRef> runtimeFunctionName = {});
-FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintOpenFn(OpBuilder &b,
-                                                      Operation *moduleOp);
-FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintCloseFn(OpBuilder &b,
-                                                       Operation *moduleOp);
-FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintCommaFn(OpBuilder &b,
-                                                       Operation *moduleOp);
-FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintNewlineFn(OpBuilder &b,
-                                                         Operation *moduleOp);
+                            std::optional<StringRef> runtimeFunctionName = {},
+                            SymbolTableCollection *symbolTables = nullptr);
+FailureOr<LLVM::LLVMFuncOp>
+lookupOrCreatePrintOpenFn(OpBuilder &b, Operation *moduleOp,
+                          SymbolTableCollection *symbolTables = nullptr);
+FailureOr<LLVM::LLVMFuncOp>
+lookupOrCreatePrintCloseFn(OpBuilder &b, Operation *moduleOp,
+                           SymbolTableCollection *symbolTables = nullptr);
+FailureOr<LLVM::LLVMFuncOp>
+lookupOrCreatePrintCommaFn(OpBuilder &b, Operation *moduleOp,
+                           SymbolTableCollection *symbolTables = nullptr);
+FailureOr<LLVM::LLVMFuncOp>
+lookupOrCreatePrintNewlineFn(OpBuilder &b, Operation *moduleOp,
+                             SymbolTableCollection *symbolTables = nullptr);
+FailureOr<LLVM::LLVMFuncOp>
+lookupOrCreateMallocFn(OpBuilder &b, Operation *moduleOp, Type indexType,
+                       SymbolTableCollection *symbolTables = nullptr);
 FailureOr<LLVM::LLVMFuncOp>
-lookupOrCreateMallocFn(OpBuilder &b, Operation *moduleOp, Type indexType);
+lookupOrCreateAlignedAllocFn(OpBuilder &b, Operation *moduleOp, Type indexType,
+                             SymbolTableCollection *symbolTables = nullptr);
 FailureOr<LLVM::LLVMFuncOp>
-lookupOrCreateAlignedAllocFn(OpBuilder &b, Operation *moduleOp, Type indexType);
-FailureOr<LLVM::LLVMFuncOp> lookupOrCreateFreeFn(OpBuilder &b,
-                                                 Operation *moduleOp);
+lookupOrCreateFreeFn(OpBuilder &b, Operation *moduleOp,
+                     SymbolTableCollection *symbolTables = nullptr);
 FailureOr<LLVM::LLVMFuncOp>
-lookupOrCreateGenericAllocFn(OpBuilder &b, Operation *moduleOp, Type indexType);
+lookupOrCreateGenericAllocFn(OpBuilder &b, Operation *moduleOp, Type indexType,
+                             SymbolTableCollection *symbolTables = nullptr);
+FailureOr<LLVM::LLVMFuncOp> lookupOrCreateGenericAlignedAllocFn(
+    OpBuilder &b, Operation *moduleOp, Type indexType,
+    SymbolTableCollection *symbolTables = nullptr);
 FailureOr<LLVM::LLVMFuncOp>
-lookupOrCreateGenericAlignedAllocFn(OpBuilder &b, Operation *moduleOp,
-                                    Type indexType);
-FailureOr<LLVM::LLVMFuncOp> lookupOrCreateGenericFreeFn(OpBuilder &b,
-                                                        Operation *moduleOp);
+lookupOrCreateGenericFreeFn(OpBuilder &b, Operation *moduleOp,
+                            SymbolTableCollection *symbolTables = nullptr);
 FailureOr<LLVM::LLVMFuncOp>
 lookupOrCreateMemRefCopyFn(OpBuilder &b, Operation *moduleOp, Type indexType,
-                           Type unrankedDescriptorType);
+                           Type unrankedDescriptorType,
+                           SymbolTableCollection *symbolTables = nullptr);
 
 /// Create a FuncOp with signature `resultType`(`paramTypes`)` and name `name`.
 /// Return a failure if the FuncOp found has unexpected signature.
 FailureOr<LLVM::LLVMFuncOp>
 lookupOrCreateFn(OpBuilder &b, Operation *moduleOp, StringRef name,
                  ArrayRef<Type> paramTypes = {}, Type resultType = {},
-                 bool isVarArg = false, bool isReserved = false);
+                 bool isVarArg = false, bool isReserved = false,
+                 SymbolTableCollection *symbolTables = nullptr);
 
 } // namespace LLVM
 } // namespace mlir
diff --git a/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp b/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp
index debfd003bd5b5..eaa8e7d26d4bd 100644
--- a/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp
+++ b/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp
@@ -44,9 +44,10 @@ namespace {
 /// lowering.
 struct AssertOpLowering : public ConvertOpToLLVMPattern<cf::AssertOp> {
   explicit AssertOpLowering(const LLVMTypeConverter &typeConverter,
-                            bool abortOnFailedAssert = true)
+                            bool abortOnFailedAssert = true,
+                            SymbolTableCollection *symbolTables = nullptr)
       : ConvertOpToLLVMPattern<cf::AssertOp>(typeConverter, /*benefit=*/1),
-        abortOnFailedAssert(abortOnFailedAssert) {}
+        abortOnFailedAssert(abortOnFailedAssert), symbolTables(symbolTables) {}
 
   LogicalResult
   matchAndRewrite(cf::AssertOp op, OpAdaptor adaptor,
@@ -64,7 +65,7 @@ struct AssertOpLowering : public ConvertOpToLLVMPattern<cf::AssertOp> {
     auto createResult = LLVM::createPrintStrCall(
         rewriter, loc, module, "assert_msg", op.getMsg(), *getTypeConverter(),
         /*addNewLine=*/false,
-        /*runtimeFunctionName=*/"puts");
+        /*runtimeFunctionName=*/"puts", symbolTables);
     if (createResult.failed())
       return failure();
 
@@ -96,6 +97,8 @@ struct AssertOpLowering : public ConvertOpToLLVMPattern<cf::AssertOp> {
   /// If set to `false`, messages are printed but program execution continues.
   /// This is useful for testing asserts.
   bool abortOnFailedAssert = true;
+
+  SymbolTableCollection *symbolTables = nullptr;
 };
 
 /// Helper function for converting branch ops. This function converts the
@@ -227,8 +230,8 @@ void mlir::cf::populateControlFlowToLLVMConversionPatterns(
 
 void mlir::cf::populateAssertToLLVMConversionPattern(
     const LLVMTypeConverter &converter, RewritePatternSet &patterns,
-    bool abortOnFailure) {
-  patterns.add<AssertOpLowering>(converter, abortOnFailure);
+    bool abortOnFailure, SymbolTableCollection *symbolTables) {
+  patterns.add<AssertOpLowering>(converter, abortOnFailure, symbolTables);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
index 328c605add65c..6a6371921c1d5 100644
--- a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
+++ b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
@@ -299,10 +299,9 @@ static void restoreByValRefArgumentType(
   }
 }
 
-FailureOr<LLVM::LLVMFuncOp>
-mlir::convertFuncOpToLLVMFuncOp(FunctionOpInterface funcOp,
-                                ConversionPatternRewriter &rewriter,
-                                const LLVMTypeConverter &converter) {
+FailureOr<LLVM::LLVMFuncOp> mlir::convertFuncOpToLLVMFuncOp(
+    FunctionOpInterface funcOp, ConversionPatternRewriter &rewriter,
+    const LLVMTypeConverter &converter, SymbolTableCollection *symbolTables) {
   // Check the funcOp has `FunctionType`.
   auto funcTy = dyn_cast<FunctionType>(funcOp.getFunctionType());
   if (!funcTy)
@@ -365,10 +364,25 @@ mlir::convertFuncOpToLLVMFuncOp(FunctionOpInterface funcOp,
 
   SmallVector<NamedAttribute, 4> attributes;
   filterFuncAttributes(funcOp, attributes);
+
+  Operation *symbolTableOp = funcOp->getParentWithTrait<OpTrait::SymbolTable>();
+
+  if (symbolTables && symbolTableOp) {
+    SymbolTable &symbolTable = symbolTables->getSymbolTable(symbolTableOp);
+    symbolTable.remove(funcOp);
+  }
+
   auto newFuncOp = rewriter.create<LLVM::LLVMFuncOp>(
       funcOp.getLoc(), funcOp.getName(), llvmType, linkage,
       /*dsoLocal=*/false, /*cconv=*/LLVM::CConv::C, /*comdat=*/nullptr,
       attributes);
+
+  if (symbolTables && symbolTableOp) {
+    auto ip = rewriter.getInsertionPoint();
+    SymbolTable &symbolTable = symbolTables->getSymbolTable(symbolTableOp);
+    symbolTable.insert(newFuncOp, ip);
+  }
+
   cast<FunctionOpInterface>(newFuncOp.getOperation())
       .setVisibility(funcOp.getVisibility());
 
@@ -477,16 +491,20 @@ namespace {
 /// FuncOp legalization pattern that converts MemRef arguments to pointers to
 /// MemRef descriptors (LLVM struct data types) containing all the MemRef type
 /// information.
-struct FuncOpConversion : public ConvertOpToLLVMPattern<func::FuncOp> {
-  FuncOpConversion(const LLVMTypeConverter &converter)
-      : ConvertOpToLLVMPattern(converter) {}
+class FuncOpConversion : public ConvertOpToLLVMPattern<func::FuncOp> {
+  SymbolTableCollection *symbolTables = nullptr;
+
+public:
+  explicit FuncOpConversion(const LLVMTypeConverter &converter,
+                            SymbolTableCollection *symbolTables = nullptr)
+      : ConvertOpToLLVMPattern(converter), symbolTables(symbolTables) {}
 
   LogicalResult
   matchAndRewrite(func::FuncOp funcOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     FailureOr<LLVM::LLVMFuncOp> newFuncOp = mlir::convertFuncOpToLLVMFuncOp(
         cast<FunctionOpInterface>(funcOp.getOperation()), rewriter,
-        *getTypeConverter());
+        *getTypeConverter(), symbolTables);
     if (failed(newFuncOp))
       return rewriter.notifyMatchFailure(funcOp, "Could not convert funcop");
 
@@ -595,11 +613,12 @@ struct CallOpInterfaceLowering : public ConvertOpToLLVMPattern<CallOpType> {
 
 class CallOpLowering : public CallOpInterfaceLowering<func::CallOp> {
 public:
-  CallOpLowering(const LLVMTypeConverter &typeConverter,
-                 // Can be nullptr.
-                 const SymbolTable *symbolTable, PatternBenefit benefit = 1)
+  explicit CallOpLowering(const LLVMTypeConverter &typeConverter,
+                          // Can be nullptr.
+                          SymbolTableCollection *symbolTables = nullptr,
+                          PatternBenefit benefit = 1)
       : CallOpInterfaceLowering<func::CallOp>(typeConverter, benefit),
-        symbolTable(symbolTable) {}
+        symbolTables(symbolTables) {}
 
   LogicalResult
   matchAndRewrite(func::CallOp callOp, OpAdaptor adaptor,
@@ -607,10 +626,10 @@ class CallOpLowering : public CallOpInterfaceLowering<func::CallOp> {
     bool useBarePtrCallConv = false;
     if (getTypeConverter()->getOptions().useBarePtrCallConv) {
       useBarePtrCallConv = true;
-    } else if (symbolTable != nullptr) {
+    } else if (symbolTables != nullptr) {
       // Fast lookup.
       Operation *callee =
-          symbolTable->lookup(callOp.getCalleeAttr().getValue());
+          symbolTables->lookupNearestSymbolFrom(callOp, callOp.getCalleeAttr());
       useBarePtrCallConv =
           callee != nullptr && callee->hasAttr(barePtrAttrName);
     } else {
@@ -624,7 +643,7 @@ class CallOpLowering : public CallOpInterfaceLowering<func::CallOp> {
   }
 
 private:
-  const SymbolTable *symbolTable = nullptr;
+  SymbolTableCollection *symbolTables = nullptr;
 };
 
 struct CallIndirectOpLowering
@@ -735,16 +754,17 @@ struct ReturnOpLowering : public ConvertOpToLLVMPattern<func::ReturnOp> {
 } // namespace
 
 void mlir::populateFuncToLLVMFuncOpConversionPattern(
-    const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
-  patterns.add<FuncOpConversion>(converter);
+    const LLVMTypeConverter &converter, RewritePatternSet &patterns,
+    SymbolTableCollection *symbolTables) {
+  patterns.add<FuncOpConversion>(converter, symbolTables);
 }
 
 void mlir::populateFuncToLLVMConversionPatterns(
     const LLVMTypeConverter &converter, RewritePatternSet &patterns,
-    const SymbolTable *symbolTable) {
-  populateFuncToLLVMFuncOpConversionPattern(converter, patterns);
+    SymbolTableCollection *symbolTables) {
+  populateFuncToLLVMFuncOpConversionPattern(converter, patterns, symbolTables);
   patterns.add<CallIndirectOpLowering>(converter);
-  patterns.add<CallOpLowering>(converter, symbolTable);
+  patterns.add<CallOpLowering>(converter, symbolTables);
   patterns.add<ConstantOpLowering>(converter);
   patterns.add<ReturnOpLowering>(converter);
 }
@@ -784,15 +804,11 @@ struct ConvertFuncToLLVMPass
     LLVMTypeConverter typeConverter(&getContext(), options,
                                     &dataLayoutAnalysis);
 
-    std::optional<SymbolTable> optSymbolTable = std::nullopt;
-    const SymbolTable *symbolTable = nullptr;
-...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented Jun 13, 2025

@llvm/pr-subscribers-mlir-llvm

Author: Michele Scuttari (mscuttari)

Changes

This PR allows to optionally speed up the lookup of symbols by providing a SymbolTableCollection instance to the interested conversion patterns.
It is follow-up on the discussion about symbol / symbol table management carried on Discourse.


Patch is 52.27 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/144032.diff

11 Files Affected:

  • (modified) mlir/include/mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h (+4-3)
  • (modified) mlir/include/mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h (+6-3)
  • (modified) mlir/include/mlir/Conversion/LLVMCommon/PrintCallHelper.h (+3-1)
  • (modified) mlir/include/mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h (+3-1)
  • (modified) mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h (+51-32)
  • (modified) mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp (+8-5)
  • (modified) mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp (+44-28)
  • (modified) mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp (+19-4)
  • (modified) mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp (+102-38)
  • (modified) mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp (+24-12)
  • (modified) mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp (+89-49)
diff --git a/mlir/include/mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h b/mlir/include/mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h
index 88f18022da9bb..2dfb6b03bcfcd 100644
--- a/mlir/include/mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h
+++ b/mlir/include/mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h
@@ -20,6 +20,7 @@ class DialectRegistry;
 class LLVMTypeConverter;
 class RewritePatternSet;
 class Pass;
+class SymbolTableCollection;
 
 #define GEN_PASS_DECL_CONVERTCONTROLFLOWTOLLVMPASS
 #include "mlir/Conversion/Passes.h.inc"
@@ -39,9 +40,9 @@ void populateControlFlowToLLVMConversionPatterns(
 /// Populate the cf.assert to LLVM conversion pattern. If `abortOnFailure` is
 /// set to false, the program execution continues when a condition is
 /// unsatisfied.
-void populateAssertToLLVMConversionPattern(const LLVMTypeConverter &converter,
-                                           RewritePatternSet &patterns,
-                                           bool abortOnFailure = true);
+void populateAssertToLLVMConversionPattern(
+    const LLVMTypeConverter &converter, RewritePatternSet &patterns,
+    bool abortOnFailure = true, SymbolTableCollection *symbolTables = nullptr);
 
 void registerConvertControlFlowToLLVMInterface(DialectRegistry &registry);
 
diff --git a/mlir/include/mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h b/mlir/include/mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h
index b1ea2740c0605..e530b0a43b8e0 100644
--- a/mlir/include/mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h
+++ b/mlir/include/mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h
@@ -27,20 +27,23 @@ class DialectRegistry;
 class LLVMTypeConverter;
 class RewritePatternSet;
 class SymbolTable;
+class SymbolTableCollection;
 
 /// Convert input FunctionOpInterface operation to LLVMFuncOp by using the
 /// provided LLVMTypeConverter. Return failure if failed to so.
 FailureOr<LLVM::LLVMFuncOp>
 convertFuncOpToLLVMFuncOp(FunctionOpInterface funcOp,
                           ConversionPatternRewriter &rewriter,
-                          const LLVMTypeConverter &converter);
+                          const LLVMTypeConverter &converter,
+                          SymbolTableCollection *symbolTables = nullptr);
 
 /// Collect the default pattern to convert a FuncOp to the LLVM dialect. If
 /// `emitCWrappers` is set, the pattern will also produce functions
 /// that pass memref descriptors by pointer-to-structure in addition to the
 /// default unpacked form.
 void populateFuncToLLVMFuncOpConversionPattern(
-    const LLVMTypeConverter &converter, RewritePatternSet &patterns);
+    const LLVMTypeConverter &converter, RewritePatternSet &patterns,
+    SymbolTableCollection *symbolTables = nullptr);
 
 /// Collect the patterns to convert from the Func dialect to LLVM. The
 /// conversion patterns capture the LLVMTypeConverter and the LowerToLLVMOptions
@@ -57,7 +60,7 @@ void populateFuncToLLVMFuncOpConversionPattern(
 /// not an error to provide it anyway.
 void populateFuncToLLVMConversionPatterns(
     const LLVMTypeConverter &converter, RewritePatternSet &patterns,
-    const SymbolTable *symbolTable = nullptr);
+    SymbolTableCollection *symbolTables = nullptr);
 
 void registerConvertFuncToLLVMInterface(DialectRegistry &registry);
 
diff --git a/mlir/include/mlir/Conversion/LLVMCommon/PrintCallHelper.h b/mlir/include/mlir/Conversion/LLVMCommon/PrintCallHelper.h
index 33402301115b7..d7de40555bb6a 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/PrintCallHelper.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/PrintCallHelper.h
@@ -17,6 +17,7 @@ namespace mlir {
 
 class OpBuilder;
 class LLVMTypeConverter;
+class SymbolTableCollection;
 
 namespace LLVM {
 
@@ -26,7 +27,8 @@ namespace LLVM {
 LogicalResult createPrintStrCall(
     OpBuilder &builder, Location loc, ModuleOp moduleOp, StringRef symbolName,
     StringRef string, const LLVMTypeConverter &typeConverter,
-    bool addNewline = true, std::optional<StringRef> runtimeFunctionName = {});
+    bool addNewline = true, std::optional<StringRef> runtimeFunctionName = {},
+    SymbolTableCollection *symbolTables = nullptr);
 } // namespace LLVM
 
 } // namespace mlir
diff --git a/mlir/include/mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h b/mlir/include/mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h
index 996a64baf9dd5..e93d5bdce7bf2 100644
--- a/mlir/include/mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h
+++ b/mlir/include/mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h
@@ -16,6 +16,7 @@ class DialectRegistry;
 class Pass;
 class LLVMTypeConverter;
 class RewritePatternSet;
+class SymbolTableCollection;
 
 #define GEN_PASS_DECL_FINALIZEMEMREFTOLLVMCONVERSIONPASS
 #include "mlir/Conversion/Passes.h.inc"
@@ -23,7 +24,8 @@ class RewritePatternSet;
 /// Collect a set of patterns to convert memory-related operations from the
 /// MemRef dialect to the LLVM dialect.
 void populateFinalizeMemRefToLLVMConversionPatterns(
-    const LLVMTypeConverter &converter, RewritePatternSet &patterns);
+    const LLVMTypeConverter &converter, RewritePatternSet &patterns,
+    SymbolTableCollection *symbolTables = nullptr);
 
 void registerConvertMemRefToLLVMInterface(DialectRegistry &registry);
 
diff --git a/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h b/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h
index 4a7ec6f2efe64..8ad9ed18acebd 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h
@@ -24,6 +24,7 @@ class OpBuilder;
 class Operation;
 class Type;
 class ValueRange;
+class SymbolTableCollection;
 
 namespace LLVM {
 class LLVMFuncOp;
@@ -33,55 +34,73 @@ class LLVMFuncOp;
 /// implemented separately (e.g. as part of a support runtime library or as part
 /// of the libc).
 /// Failure if an unexpected version of function is found.
-FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintI64Fn(OpBuilder &b,
-                                                     Operation *moduleOp);
-FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintU64Fn(OpBuilder &b,
-                                                     Operation *moduleOp);
-FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintF16Fn(OpBuilder &b,
-                                                     Operation *moduleOp);
-FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintBF16Fn(OpBuilder &b,
-                                                      Operation *moduleOp);
-FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintF32Fn(OpBuilder &b,
-                                                     Operation *moduleOp);
-FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintF64Fn(OpBuilder &b,
-                                                     Operation *moduleOp);
+FailureOr<LLVM::LLVMFuncOp>
+lookupOrCreatePrintI64Fn(OpBuilder &b, Operation *moduleOp,
+                         SymbolTableCollection *symbolTables = nullptr);
+FailureOr<LLVM::LLVMFuncOp>
+lookupOrCreatePrintU64Fn(OpBuilder &b, Operation *moduleOp,
+                         SymbolTableCollection *symbolTables = nullptr);
+FailureOr<LLVM::LLVMFuncOp>
+lookupOrCreatePrintF16Fn(OpBuilder &b, Operation *moduleOp,
+                         SymbolTableCollection *symbolTables = nullptr);
+FailureOr<LLVM::LLVMFuncOp>
+lookupOrCreatePrintBF16Fn(OpBuilder &b, Operation *moduleOp,
+                          SymbolTableCollection *symbolTables = nullptr);
+FailureOr<LLVM::LLVMFuncOp>
+lookupOrCreatePrintF32Fn(OpBuilder &b, Operation *moduleOp,
+                         SymbolTableCollection *symbolTables = nullptr);
+FailureOr<LLVM::LLVMFuncOp>
+lookupOrCreatePrintF64Fn(OpBuilder &b, Operation *moduleOp,
+                         SymbolTableCollection *symbolTables = nullptr);
 /// Declares a function to print a C-string.
 /// If a custom runtime function is defined via `runtimeFunctionName`, it must
 /// have the signature void(char const*). The default function is `printString`.
 FailureOr<LLVM::LLVMFuncOp>
 lookupOrCreatePrintStringFn(OpBuilder &b, Operation *moduleOp,
-                            std::optional<StringRef> runtimeFunctionName = {});
-FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintOpenFn(OpBuilder &b,
-                                                      Operation *moduleOp);
-FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintCloseFn(OpBuilder &b,
-                                                       Operation *moduleOp);
-FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintCommaFn(OpBuilder &b,
-                                                       Operation *moduleOp);
-FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintNewlineFn(OpBuilder &b,
-                                                         Operation *moduleOp);
+                            std::optional<StringRef> runtimeFunctionName = {},
+                            SymbolTableCollection *symbolTables = nullptr);
+FailureOr<LLVM::LLVMFuncOp>
+lookupOrCreatePrintOpenFn(OpBuilder &b, Operation *moduleOp,
+                          SymbolTableCollection *symbolTables = nullptr);
+FailureOr<LLVM::LLVMFuncOp>
+lookupOrCreatePrintCloseFn(OpBuilder &b, Operation *moduleOp,
+                           SymbolTableCollection *symbolTables = nullptr);
+FailureOr<LLVM::LLVMFuncOp>
+lookupOrCreatePrintCommaFn(OpBuilder &b, Operation *moduleOp,
+                           SymbolTableCollection *symbolTables = nullptr);
+FailureOr<LLVM::LLVMFuncOp>
+lookupOrCreatePrintNewlineFn(OpBuilder &b, Operation *moduleOp,
+                             SymbolTableCollection *symbolTables = nullptr);
+FailureOr<LLVM::LLVMFuncOp>
+lookupOrCreateMallocFn(OpBuilder &b, Operation *moduleOp, Type indexType,
+                       SymbolTableCollection *symbolTables = nullptr);
 FailureOr<LLVM::LLVMFuncOp>
-lookupOrCreateMallocFn(OpBuilder &b, Operation *moduleOp, Type indexType);
+lookupOrCreateAlignedAllocFn(OpBuilder &b, Operation *moduleOp, Type indexType,
+                             SymbolTableCollection *symbolTables = nullptr);
 FailureOr<LLVM::LLVMFuncOp>
-lookupOrCreateAlignedAllocFn(OpBuilder &b, Operation *moduleOp, Type indexType);
-FailureOr<LLVM::LLVMFuncOp> lookupOrCreateFreeFn(OpBuilder &b,
-                                                 Operation *moduleOp);
+lookupOrCreateFreeFn(OpBuilder &b, Operation *moduleOp,
+                     SymbolTableCollection *symbolTables = nullptr);
 FailureOr<LLVM::LLVMFuncOp>
-lookupOrCreateGenericAllocFn(OpBuilder &b, Operation *moduleOp, Type indexType);
+lookupOrCreateGenericAllocFn(OpBuilder &b, Operation *moduleOp, Type indexType,
+                             SymbolTableCollection *symbolTables = nullptr);
+FailureOr<LLVM::LLVMFuncOp> lookupOrCreateGenericAlignedAllocFn(
+    OpBuilder &b, Operation *moduleOp, Type indexType,
+    SymbolTableCollection *symbolTables = nullptr);
 FailureOr<LLVM::LLVMFuncOp>
-lookupOrCreateGenericAlignedAllocFn(OpBuilder &b, Operation *moduleOp,
-                                    Type indexType);
-FailureOr<LLVM::LLVMFuncOp> lookupOrCreateGenericFreeFn(OpBuilder &b,
-                                                        Operation *moduleOp);
+lookupOrCreateGenericFreeFn(OpBuilder &b, Operation *moduleOp,
+                            SymbolTableCollection *symbolTables = nullptr);
 FailureOr<LLVM::LLVMFuncOp>
 lookupOrCreateMemRefCopyFn(OpBuilder &b, Operation *moduleOp, Type indexType,
-                           Type unrankedDescriptorType);
+                           Type unrankedDescriptorType,
+                           SymbolTableCollection *symbolTables = nullptr);
 
 /// Create a FuncOp with signature `resultType`(`paramTypes`)` and name `name`.
 /// Return a failure if the FuncOp found has unexpected signature.
 FailureOr<LLVM::LLVMFuncOp>
 lookupOrCreateFn(OpBuilder &b, Operation *moduleOp, StringRef name,
                  ArrayRef<Type> paramTypes = {}, Type resultType = {},
-                 bool isVarArg = false, bool isReserved = false);
+                 bool isVarArg = false, bool isReserved = false,
+                 SymbolTableCollection *symbolTables = nullptr);
 
 } // namespace LLVM
 } // namespace mlir
diff --git a/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp b/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp
index debfd003bd5b5..eaa8e7d26d4bd 100644
--- a/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp
+++ b/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp
@@ -44,9 +44,10 @@ namespace {
 /// lowering.
 struct AssertOpLowering : public ConvertOpToLLVMPattern<cf::AssertOp> {
   explicit AssertOpLowering(const LLVMTypeConverter &typeConverter,
-                            bool abortOnFailedAssert = true)
+                            bool abortOnFailedAssert = true,
+                            SymbolTableCollection *symbolTables = nullptr)
       : ConvertOpToLLVMPattern<cf::AssertOp>(typeConverter, /*benefit=*/1),
-        abortOnFailedAssert(abortOnFailedAssert) {}
+        abortOnFailedAssert(abortOnFailedAssert), symbolTables(symbolTables) {}
 
   LogicalResult
   matchAndRewrite(cf::AssertOp op, OpAdaptor adaptor,
@@ -64,7 +65,7 @@ struct AssertOpLowering : public ConvertOpToLLVMPattern<cf::AssertOp> {
     auto createResult = LLVM::createPrintStrCall(
         rewriter, loc, module, "assert_msg", op.getMsg(), *getTypeConverter(),
         /*addNewLine=*/false,
-        /*runtimeFunctionName=*/"puts");
+        /*runtimeFunctionName=*/"puts", symbolTables);
     if (createResult.failed())
       return failure();
 
@@ -96,6 +97,8 @@ struct AssertOpLowering : public ConvertOpToLLVMPattern<cf::AssertOp> {
   /// If set to `false`, messages are printed but program execution continues.
   /// This is useful for testing asserts.
   bool abortOnFailedAssert = true;
+
+  SymbolTableCollection *symbolTables = nullptr;
 };
 
 /// Helper function for converting branch ops. This function converts the
@@ -227,8 +230,8 @@ void mlir::cf::populateControlFlowToLLVMConversionPatterns(
 
 void mlir::cf::populateAssertToLLVMConversionPattern(
     const LLVMTypeConverter &converter, RewritePatternSet &patterns,
-    bool abortOnFailure) {
-  patterns.add<AssertOpLowering>(converter, abortOnFailure);
+    bool abortOnFailure, SymbolTableCollection *symbolTables) {
+  patterns.add<AssertOpLowering>(converter, abortOnFailure, symbolTables);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
index 328c605add65c..6a6371921c1d5 100644
--- a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
+++ b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
@@ -299,10 +299,9 @@ static void restoreByValRefArgumentType(
   }
 }
 
-FailureOr<LLVM::LLVMFuncOp>
-mlir::convertFuncOpToLLVMFuncOp(FunctionOpInterface funcOp,
-                                ConversionPatternRewriter &rewriter,
-                                const LLVMTypeConverter &converter) {
+FailureOr<LLVM::LLVMFuncOp> mlir::convertFuncOpToLLVMFuncOp(
+    FunctionOpInterface funcOp, ConversionPatternRewriter &rewriter,
+    const LLVMTypeConverter &converter, SymbolTableCollection *symbolTables) {
   // Check the funcOp has `FunctionType`.
   auto funcTy = dyn_cast<FunctionType>(funcOp.getFunctionType());
   if (!funcTy)
@@ -365,10 +364,25 @@ mlir::convertFuncOpToLLVMFuncOp(FunctionOpInterface funcOp,
 
   SmallVector<NamedAttribute, 4> attributes;
   filterFuncAttributes(funcOp, attributes);
+
+  Operation *symbolTableOp = funcOp->getParentWithTrait<OpTrait::SymbolTable>();
+
+  if (symbolTables && symbolTableOp) {
+    SymbolTable &symbolTable = symbolTables->getSymbolTable(symbolTableOp);
+    symbolTable.remove(funcOp);
+  }
+
   auto newFuncOp = rewriter.create<LLVM::LLVMFuncOp>(
       funcOp.getLoc(), funcOp.getName(), llvmType, linkage,
       /*dsoLocal=*/false, /*cconv=*/LLVM::CConv::C, /*comdat=*/nullptr,
       attributes);
+
+  if (symbolTables && symbolTableOp) {
+    auto ip = rewriter.getInsertionPoint();
+    SymbolTable &symbolTable = symbolTables->getSymbolTable(symbolTableOp);
+    symbolTable.insert(newFuncOp, ip);
+  }
+
   cast<FunctionOpInterface>(newFuncOp.getOperation())
       .setVisibility(funcOp.getVisibility());
 
@@ -477,16 +491,20 @@ namespace {
 /// FuncOp legalization pattern that converts MemRef arguments to pointers to
 /// MemRef descriptors (LLVM struct data types) containing all the MemRef type
 /// information.
-struct FuncOpConversion : public ConvertOpToLLVMPattern<func::FuncOp> {
-  FuncOpConversion(const LLVMTypeConverter &converter)
-      : ConvertOpToLLVMPattern(converter) {}
+class FuncOpConversion : public ConvertOpToLLVMPattern<func::FuncOp> {
+  SymbolTableCollection *symbolTables = nullptr;
+
+public:
+  explicit FuncOpConversion(const LLVMTypeConverter &converter,
+                            SymbolTableCollection *symbolTables = nullptr)
+      : ConvertOpToLLVMPattern(converter), symbolTables(symbolTables) {}
 
   LogicalResult
   matchAndRewrite(func::FuncOp funcOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     FailureOr<LLVM::LLVMFuncOp> newFuncOp = mlir::convertFuncOpToLLVMFuncOp(
         cast<FunctionOpInterface>(funcOp.getOperation()), rewriter,
-        *getTypeConverter());
+        *getTypeConverter(), symbolTables);
     if (failed(newFuncOp))
       return rewriter.notifyMatchFailure(funcOp, "Could not convert funcop");
 
@@ -595,11 +613,12 @@ struct CallOpInterfaceLowering : public ConvertOpToLLVMPattern<CallOpType> {
 
 class CallOpLowering : public CallOpInterfaceLowering<func::CallOp> {
 public:
-  CallOpLowering(const LLVMTypeConverter &typeConverter,
-                 // Can be nullptr.
-                 const SymbolTable *symbolTable, PatternBenefit benefit = 1)
+  explicit CallOpLowering(const LLVMTypeConverter &typeConverter,
+                          // Can be nullptr.
+                          SymbolTableCollection *symbolTables = nullptr,
+                          PatternBenefit benefit = 1)
       : CallOpInterfaceLowering<func::CallOp>(typeConverter, benefit),
-        symbolTable(symbolTable) {}
+        symbolTables(symbolTables) {}
 
   LogicalResult
   matchAndRewrite(func::CallOp callOp, OpAdaptor adaptor,
@@ -607,10 +626,10 @@ class CallOpLowering : public CallOpInterfaceLowering<func::CallOp> {
     bool useBarePtrCallConv = false;
     if (getTypeConverter()->getOptions().useBarePtrCallConv) {
       useBarePtrCallConv = true;
-    } else if (symbolTable != nullptr) {
+    } else if (symbolTables != nullptr) {
       // Fast lookup.
       Operation *callee =
-          symbolTable->lookup(callOp.getCalleeAttr().getValue());
+          symbolTables->lookupNearestSymbolFrom(callOp, callOp.getCalleeAttr());
       useBarePtrCallConv =
           callee != nullptr && callee->hasAttr(barePtrAttrName);
     } else {
@@ -624,7 +643,7 @@ class CallOpLowering : public CallOpInterfaceLowering<func::CallOp> {
   }
 
 private:
-  const SymbolTable *symbolTable = nullptr;
+  SymbolTableCollection *symbolTables = nullptr;
 };
 
 struct CallIndirectOpLowering
@@ -735,16 +754,17 @@ struct ReturnOpLowering : public ConvertOpToLLVMPattern<func::ReturnOp> {
 } // namespace
 
 void mlir::populateFuncToLLVMFuncOpConversionPattern(
-    const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
-  patterns.add<FuncOpConversion>(converter);
+    const LLVMTypeConverter &converter, RewritePatternSet &patterns,
+    SymbolTableCollection *symbolTables) {
+  patterns.add<FuncOpConversion>(converter, symbolTables);
 }
 
 void mlir::populateFuncToLLVMConversionPatterns(
     const LLVMTypeConverter &converter, RewritePatternSet &patterns,
-    const SymbolTable *symbolTable) {
-  populateFuncToLLVMFuncOpConversionPattern(converter, patterns);
+    SymbolTableCollection *symbolTables) {
+  populateFuncToLLVMFuncOpConversionPattern(converter, patterns, symbolTables);
   patterns.add<CallIndirectOpLowering>(converter);
-  patterns.add<CallOpLowering>(converter, symbolTable);
+  patterns.add<CallOpLowering>(converter, symbolTables);
   patterns.add<ConstantOpLowering>(converter);
   patterns.add<ReturnOpLowering>(converter);
 }
@@ -784,15 +804,11 @@ struct ConvertFuncToLLVMPass
     LLVMTypeConverter typeConverter(&getContext(), options,
                                     &dataLayoutAnalysis);
 
-    std::optional<SymbolTable> optSymbolTable = std::nullopt;
-    const SymbolTable *symbolTable = nullptr;
-...
[truncated]

@Dinistro Dinistro self-requested a review June 16, 2025 11:31
Comment on lines +37 to +54
FailureOr<LLVM::LLVMFuncOp>
lookupOrCreatePrintI64Fn(OpBuilder &b, Operation *moduleOp,
SymbolTableCollection *symbolTables = nullptr);
FailureOr<LLVM::LLVMFuncOp>
lookupOrCreatePrintU64Fn(OpBuilder &b, Operation *moduleOp,
SymbolTableCollection *symbolTables = nullptr);
FailureOr<LLVM::LLVMFuncOp>
lookupOrCreatePrintF16Fn(OpBuilder &b, Operation *moduleOp,
SymbolTableCollection *symbolTables = nullptr);
FailureOr<LLVM::LLVMFuncOp>
lookupOrCreatePrintBF16Fn(OpBuilder &b, Operation *moduleOp,
SymbolTableCollection *symbolTables = nullptr);
FailureOr<LLVM::LLVMFuncOp>
lookupOrCreatePrintF32Fn(OpBuilder &b, Operation *moduleOp,
SymbolTableCollection *symbolTables = nullptr);
FailureOr<LLVM::LLVMFuncOp>
lookupOrCreatePrintF64Fn(OpBuilder &b, Operation *moduleOp,
SymbolTableCollection *symbolTables = nullptr);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not for you to fix, but this having a separate function call for each function name instead of using an enum is what made this change so big :(

@mscuttari mscuttari merged commit bb37296 into llvm:main Jun 21, 2025
7 checks passed
Jaddyen pushed a commit to Jaddyen/llvm-project that referenced this pull request Jun 23, 2025
…lvm#144032)

This PR allows to optionally speed up the lookup of symbols by providing a `SymbolTableCollection` instance to the interested conversion patterns. It is follow-up on the discussion about symbol / symbol table management carried on [Discourse](https://discourse.llvm.org/t/symbol-table-as-first-class-citizen-in-builders/86813).
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants