Skip to content

Commit bb37296

Browse files
authored
[MLIR] Add optional cached symbol tables to LLVM conversion patterns (#144032)
This PR allows to optionally speed up the lookup of symbols by providing a `SymbolTableCollection` instance to the interested conversion patterns. It is follow-up on the discussion about symbol / symbol table management carried on [Discourse](https://discourse.llvm.org/t/symbol-table-as-first-class-citizen-in-builders/86813).
1 parent 0921bfd commit bb37296

File tree

11 files changed

+359
-176
lines changed

11 files changed

+359
-176
lines changed

mlir/include/mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ class DialectRegistry;
2020
class LLVMTypeConverter;
2121
class RewritePatternSet;
2222
class Pass;
23+
class SymbolTableCollection;
2324

2425
#define GEN_PASS_DECL_CONVERTCONTROLFLOWTOLLVMPASS
2526
#include "mlir/Conversion/Passes.h.inc"
@@ -39,9 +40,9 @@ void populateControlFlowToLLVMConversionPatterns(
3940
/// Populate the cf.assert to LLVM conversion pattern. If `abortOnFailure` is
4041
/// set to false, the program execution continues when a condition is
4142
/// unsatisfied.
42-
void populateAssertToLLVMConversionPattern(const LLVMTypeConverter &converter,
43-
RewritePatternSet &patterns,
44-
bool abortOnFailure = true);
43+
void populateAssertToLLVMConversionPattern(
44+
const LLVMTypeConverter &converter, RewritePatternSet &patterns,
45+
bool abortOnFailure = true, SymbolTableCollection *symbolTables = nullptr);
4546

4647
void registerConvertControlFlowToLLVMInterface(DialectRegistry &registry);
4748

mlir/include/mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,20 +27,23 @@ class DialectRegistry;
2727
class LLVMTypeConverter;
2828
class RewritePatternSet;
2929
class SymbolTable;
30+
class SymbolTableCollection;
3031

3132
/// Convert input FunctionOpInterface operation to LLVMFuncOp by using the
3233
/// provided LLVMTypeConverter. Return failure if failed to so.
3334
FailureOr<LLVM::LLVMFuncOp>
3435
convertFuncOpToLLVMFuncOp(FunctionOpInterface funcOp,
3536
ConversionPatternRewriter &rewriter,
36-
const LLVMTypeConverter &converter);
37+
const LLVMTypeConverter &converter,
38+
SymbolTableCollection *symbolTables = nullptr);
3739

3840
/// Collect the default pattern to convert a FuncOp to the LLVM dialect. If
3941
/// `emitCWrappers` is set, the pattern will also produce functions
4042
/// that pass memref descriptors by pointer-to-structure in addition to the
4143
/// default unpacked form.
4244
void populateFuncToLLVMFuncOpConversionPattern(
43-
const LLVMTypeConverter &converter, RewritePatternSet &patterns);
45+
const LLVMTypeConverter &converter, RewritePatternSet &patterns,
46+
SymbolTableCollection *symbolTables = nullptr);
4447

4548
/// Collect the patterns to convert from the Func dialect to LLVM. The
4649
/// conversion patterns capture the LLVMTypeConverter and the LowerToLLVMOptions
@@ -57,7 +60,7 @@ void populateFuncToLLVMFuncOpConversionPattern(
5760
/// not an error to provide it anyway.
5861
void populateFuncToLLVMConversionPatterns(
5962
const LLVMTypeConverter &converter, RewritePatternSet &patterns,
60-
const SymbolTable *symbolTable = nullptr);
63+
SymbolTableCollection *symbolTables = nullptr);
6164

6265
void registerConvertFuncToLLVMInterface(DialectRegistry &registry);
6366

mlir/include/mlir/Conversion/LLVMCommon/PrintCallHelper.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ namespace mlir {
1717

1818
class OpBuilder;
1919
class LLVMTypeConverter;
20+
class SymbolTableCollection;
2021

2122
namespace LLVM {
2223

@@ -26,7 +27,8 @@ namespace LLVM {
2627
LogicalResult createPrintStrCall(
2728
OpBuilder &builder, Location loc, ModuleOp moduleOp, StringRef symbolName,
2829
StringRef string, const LLVMTypeConverter &typeConverter,
29-
bool addNewline = true, std::optional<StringRef> runtimeFunctionName = {});
30+
bool addNewline = true, std::optional<StringRef> runtimeFunctionName = {},
31+
SymbolTableCollection *symbolTables = nullptr);
3032
} // namespace LLVM
3133

3234
} // namespace mlir

mlir/include/mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,16 @@ class DialectRegistry;
1616
class Pass;
1717
class LLVMTypeConverter;
1818
class RewritePatternSet;
19+
class SymbolTableCollection;
1920

2021
#define GEN_PASS_DECL_FINALIZEMEMREFTOLLVMCONVERSIONPASS
2122
#include "mlir/Conversion/Passes.h.inc"
2223

2324
/// Collect a set of patterns to convert memory-related operations from the
2425
/// MemRef dialect to the LLVM dialect.
2526
void populateFinalizeMemRefToLLVMConversionPatterns(
26-
const LLVMTypeConverter &converter, RewritePatternSet &patterns);
27+
const LLVMTypeConverter &converter, RewritePatternSet &patterns,
28+
SymbolTableCollection *symbolTables = nullptr);
2729

2830
void registerConvertMemRefToLLVMInterface(DialectRegistry &registry);
2931

mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h

Lines changed: 51 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ class OpBuilder;
2424
class Operation;
2525
class Type;
2626
class ValueRange;
27+
class SymbolTableCollection;
2728

2829
namespace LLVM {
2930
class LLVMFuncOp;
@@ -33,55 +34,73 @@ class LLVMFuncOp;
3334
/// implemented separately (e.g. as part of a support runtime library or as part
3435
/// of the libc).
3536
/// Failure if an unexpected version of function is found.
36-
FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintI64Fn(OpBuilder &b,
37-
Operation *moduleOp);
38-
FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintU64Fn(OpBuilder &b,
39-
Operation *moduleOp);
40-
FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintF16Fn(OpBuilder &b,
41-
Operation *moduleOp);
42-
FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintBF16Fn(OpBuilder &b,
43-
Operation *moduleOp);
44-
FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintF32Fn(OpBuilder &b,
45-
Operation *moduleOp);
46-
FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintF64Fn(OpBuilder &b,
47-
Operation *moduleOp);
37+
FailureOr<LLVM::LLVMFuncOp>
38+
lookupOrCreatePrintI64Fn(OpBuilder &b, Operation *moduleOp,
39+
SymbolTableCollection *symbolTables = nullptr);
40+
FailureOr<LLVM::LLVMFuncOp>
41+
lookupOrCreatePrintU64Fn(OpBuilder &b, Operation *moduleOp,
42+
SymbolTableCollection *symbolTables = nullptr);
43+
FailureOr<LLVM::LLVMFuncOp>
44+
lookupOrCreatePrintF16Fn(OpBuilder &b, Operation *moduleOp,
45+
SymbolTableCollection *symbolTables = nullptr);
46+
FailureOr<LLVM::LLVMFuncOp>
47+
lookupOrCreatePrintBF16Fn(OpBuilder &b, Operation *moduleOp,
48+
SymbolTableCollection *symbolTables = nullptr);
49+
FailureOr<LLVM::LLVMFuncOp>
50+
lookupOrCreatePrintF32Fn(OpBuilder &b, Operation *moduleOp,
51+
SymbolTableCollection *symbolTables = nullptr);
52+
FailureOr<LLVM::LLVMFuncOp>
53+
lookupOrCreatePrintF64Fn(OpBuilder &b, Operation *moduleOp,
54+
SymbolTableCollection *symbolTables = nullptr);
4855
/// Declares a function to print a C-string.
4956
/// If a custom runtime function is defined via `runtimeFunctionName`, it must
5057
/// have the signature void(char const*). The default function is `printString`.
5158
FailureOr<LLVM::LLVMFuncOp>
5259
lookupOrCreatePrintStringFn(OpBuilder &b, Operation *moduleOp,
53-
std::optional<StringRef> runtimeFunctionName = {});
54-
FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintOpenFn(OpBuilder &b,
55-
Operation *moduleOp);
56-
FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintCloseFn(OpBuilder &b,
57-
Operation *moduleOp);
58-
FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintCommaFn(OpBuilder &b,
59-
Operation *moduleOp);
60-
FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintNewlineFn(OpBuilder &b,
61-
Operation *moduleOp);
60+
std::optional<StringRef> runtimeFunctionName = {},
61+
SymbolTableCollection *symbolTables = nullptr);
62+
FailureOr<LLVM::LLVMFuncOp>
63+
lookupOrCreatePrintOpenFn(OpBuilder &b, Operation *moduleOp,
64+
SymbolTableCollection *symbolTables = nullptr);
65+
FailureOr<LLVM::LLVMFuncOp>
66+
lookupOrCreatePrintCloseFn(OpBuilder &b, Operation *moduleOp,
67+
SymbolTableCollection *symbolTables = nullptr);
68+
FailureOr<LLVM::LLVMFuncOp>
69+
lookupOrCreatePrintCommaFn(OpBuilder &b, Operation *moduleOp,
70+
SymbolTableCollection *symbolTables = nullptr);
71+
FailureOr<LLVM::LLVMFuncOp>
72+
lookupOrCreatePrintNewlineFn(OpBuilder &b, Operation *moduleOp,
73+
SymbolTableCollection *symbolTables = nullptr);
74+
FailureOr<LLVM::LLVMFuncOp>
75+
lookupOrCreateMallocFn(OpBuilder &b, Operation *moduleOp, Type indexType,
76+
SymbolTableCollection *symbolTables = nullptr);
6277
FailureOr<LLVM::LLVMFuncOp>
63-
lookupOrCreateMallocFn(OpBuilder &b, Operation *moduleOp, Type indexType);
78+
lookupOrCreateAlignedAllocFn(OpBuilder &b, Operation *moduleOp, Type indexType,
79+
SymbolTableCollection *symbolTables = nullptr);
6480
FailureOr<LLVM::LLVMFuncOp>
65-
lookupOrCreateAlignedAllocFn(OpBuilder &b, Operation *moduleOp, Type indexType);
66-
FailureOr<LLVM::LLVMFuncOp> lookupOrCreateFreeFn(OpBuilder &b,
67-
Operation *moduleOp);
81+
lookupOrCreateFreeFn(OpBuilder &b, Operation *moduleOp,
82+
SymbolTableCollection *symbolTables = nullptr);
6883
FailureOr<LLVM::LLVMFuncOp>
69-
lookupOrCreateGenericAllocFn(OpBuilder &b, Operation *moduleOp, Type indexType);
84+
lookupOrCreateGenericAllocFn(OpBuilder &b, Operation *moduleOp, Type indexType,
85+
SymbolTableCollection *symbolTables = nullptr);
86+
FailureOr<LLVM::LLVMFuncOp> lookupOrCreateGenericAlignedAllocFn(
87+
OpBuilder &b, Operation *moduleOp, Type indexType,
88+
SymbolTableCollection *symbolTables = nullptr);
7089
FailureOr<LLVM::LLVMFuncOp>
71-
lookupOrCreateGenericAlignedAllocFn(OpBuilder &b, Operation *moduleOp,
72-
Type indexType);
73-
FailureOr<LLVM::LLVMFuncOp> lookupOrCreateGenericFreeFn(OpBuilder &b,
74-
Operation *moduleOp);
90+
lookupOrCreateGenericFreeFn(OpBuilder &b, Operation *moduleOp,
91+
SymbolTableCollection *symbolTables = nullptr);
7592
FailureOr<LLVM::LLVMFuncOp>
7693
lookupOrCreateMemRefCopyFn(OpBuilder &b, Operation *moduleOp, Type indexType,
77-
Type unrankedDescriptorType);
94+
Type unrankedDescriptorType,
95+
SymbolTableCollection *symbolTables = nullptr);
7896

7997
/// Create a FuncOp with signature `resultType`(`paramTypes`)` and name `name`.
8098
/// Return a failure if the FuncOp found has unexpected signature.
8199
FailureOr<LLVM::LLVMFuncOp>
82100
lookupOrCreateFn(OpBuilder &b, Operation *moduleOp, StringRef name,
83101
ArrayRef<Type> paramTypes = {}, Type resultType = {},
84-
bool isVarArg = false, bool isReserved = false);
102+
bool isVarArg = false, bool isReserved = false,
103+
SymbolTableCollection *symbolTables = nullptr);
85104

86105
} // namespace LLVM
87106
} // namespace mlir

mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -44,9 +44,10 @@ namespace {
4444
/// lowering.
4545
struct AssertOpLowering : public ConvertOpToLLVMPattern<cf::AssertOp> {
4646
explicit AssertOpLowering(const LLVMTypeConverter &typeConverter,
47-
bool abortOnFailedAssert = true)
47+
bool abortOnFailedAssert = true,
48+
SymbolTableCollection *symbolTables = nullptr)
4849
: ConvertOpToLLVMPattern<cf::AssertOp>(typeConverter, /*benefit=*/1),
49-
abortOnFailedAssert(abortOnFailedAssert) {}
50+
abortOnFailedAssert(abortOnFailedAssert), symbolTables(symbolTables) {}
5051

5152
LogicalResult
5253
matchAndRewrite(cf::AssertOp op, OpAdaptor adaptor,
@@ -64,7 +65,7 @@ struct AssertOpLowering : public ConvertOpToLLVMPattern<cf::AssertOp> {
6465
auto createResult = LLVM::createPrintStrCall(
6566
rewriter, loc, module, "assert_msg", op.getMsg(), *getTypeConverter(),
6667
/*addNewLine=*/false,
67-
/*runtimeFunctionName=*/"puts");
68+
/*runtimeFunctionName=*/"puts", symbolTables);
6869
if (createResult.failed())
6970
return failure();
7071

@@ -96,6 +97,8 @@ struct AssertOpLowering : public ConvertOpToLLVMPattern<cf::AssertOp> {
9697
/// If set to `false`, messages are printed but program execution continues.
9798
/// This is useful for testing asserts.
9899
bool abortOnFailedAssert = true;
100+
101+
SymbolTableCollection *symbolTables = nullptr;
99102
};
100103

101104
/// Helper function for converting branch ops. This function converts the
@@ -232,8 +235,8 @@ void mlir::cf::populateControlFlowToLLVMConversionPatterns(
232235

233236
void mlir::cf::populateAssertToLLVMConversionPattern(
234237
const LLVMTypeConverter &converter, RewritePatternSet &patterns,
235-
bool abortOnFailure) {
236-
patterns.add<AssertOpLowering>(converter, abortOnFailure);
238+
bool abortOnFailure, SymbolTableCollection *symbolTables) {
239+
patterns.add<AssertOpLowering>(converter, abortOnFailure, symbolTables);
237240
}
238241

239242
//===----------------------------------------------------------------------===//

mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp

Lines changed: 43 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -299,10 +299,9 @@ static void restoreByValRefArgumentType(
299299
}
300300
}
301301

302-
FailureOr<LLVM::LLVMFuncOp>
303-
mlir::convertFuncOpToLLVMFuncOp(FunctionOpInterface funcOp,
304-
ConversionPatternRewriter &rewriter,
305-
const LLVMTypeConverter &converter) {
302+
FailureOr<LLVM::LLVMFuncOp> mlir::convertFuncOpToLLVMFuncOp(
303+
FunctionOpInterface funcOp, ConversionPatternRewriter &rewriter,
304+
const LLVMTypeConverter &converter, SymbolTableCollection *symbolTables) {
306305
// Check the funcOp has `FunctionType`.
307306
auto funcTy = dyn_cast<FunctionType>(funcOp.getFunctionType());
308307
if (!funcTy)
@@ -361,10 +360,25 @@ mlir::convertFuncOpToLLVMFuncOp(FunctionOpInterface funcOp,
361360

362361
SmallVector<NamedAttribute, 4> attributes;
363362
filterFuncAttributes(funcOp, attributes);
363+
364+
Operation *symbolTableOp = funcOp->getParentWithTrait<OpTrait::SymbolTable>();
365+
366+
if (symbolTables && symbolTableOp) {
367+
SymbolTable &symbolTable = symbolTables->getSymbolTable(symbolTableOp);
368+
symbolTable.remove(funcOp);
369+
}
370+
364371
auto newFuncOp = rewriter.create<LLVM::LLVMFuncOp>(
365372
funcOp.getLoc(), funcOp.getName(), llvmType, linkage,
366373
/*dsoLocal=*/false, /*cconv=*/LLVM::CConv::C, /*comdat=*/nullptr,
367374
attributes);
375+
376+
if (symbolTables && symbolTableOp) {
377+
auto ip = rewriter.getInsertionPoint();
378+
SymbolTable &symbolTable = symbolTables->getSymbolTable(symbolTableOp);
379+
symbolTable.insert(newFuncOp, ip);
380+
}
381+
368382
cast<FunctionOpInterface>(newFuncOp.getOperation())
369383
.setVisibility(funcOp.getVisibility());
370384

@@ -473,16 +487,20 @@ namespace {
473487
/// FuncOp legalization pattern that converts MemRef arguments to pointers to
474488
/// MemRef descriptors (LLVM struct data types) containing all the MemRef type
475489
/// information.
476-
struct FuncOpConversion : public ConvertOpToLLVMPattern<func::FuncOp> {
477-
FuncOpConversion(const LLVMTypeConverter &converter)
478-
: ConvertOpToLLVMPattern(converter) {}
490+
class FuncOpConversion : public ConvertOpToLLVMPattern<func::FuncOp> {
491+
SymbolTableCollection *symbolTables = nullptr;
492+
493+
public:
494+
explicit FuncOpConversion(const LLVMTypeConverter &converter,
495+
SymbolTableCollection *symbolTables = nullptr)
496+
: ConvertOpToLLVMPattern(converter), symbolTables(symbolTables) {}
479497

480498
LogicalResult
481499
matchAndRewrite(func::FuncOp funcOp, OpAdaptor adaptor,
482500
ConversionPatternRewriter &rewriter) const override {
483501
FailureOr<LLVM::LLVMFuncOp> newFuncOp = mlir::convertFuncOpToLLVMFuncOp(
484502
cast<FunctionOpInterface>(funcOp.getOperation()), rewriter,
485-
*getTypeConverter());
503+
*getTypeConverter(), symbolTables);
486504
if (failed(newFuncOp))
487505
return rewriter.notifyMatchFailure(funcOp, "Could not convert funcop");
488506

@@ -591,22 +609,22 @@ struct CallOpInterfaceLowering : public ConvertOpToLLVMPattern<CallOpType> {
591609

592610
class CallOpLowering : public CallOpInterfaceLowering<func::CallOp> {
593611
public:
594-
CallOpLowering(const LLVMTypeConverter &typeConverter,
595-
// Can be nullptr.
596-
const SymbolTable *symbolTable, PatternBenefit benefit = 1)
612+
explicit CallOpLowering(const LLVMTypeConverter &typeConverter,
613+
SymbolTableCollection *symbolTables = nullptr,
614+
PatternBenefit benefit = 1)
597615
: CallOpInterfaceLowering<func::CallOp>(typeConverter, benefit),
598-
symbolTable(symbolTable) {}
616+
symbolTables(symbolTables) {}
599617

600618
LogicalResult
601619
matchAndRewrite(func::CallOp callOp, OpAdaptor adaptor,
602620
ConversionPatternRewriter &rewriter) const override {
603621
bool useBarePtrCallConv = false;
604622
if (getTypeConverter()->getOptions().useBarePtrCallConv) {
605623
useBarePtrCallConv = true;
606-
} else if (symbolTable != nullptr) {
624+
} else if (symbolTables != nullptr) {
607625
// Fast lookup.
608626
Operation *callee =
609-
symbolTable->lookup(callOp.getCalleeAttr().getValue());
627+
symbolTables->lookupNearestSymbolFrom(callOp, callOp.getCalleeAttr());
610628
useBarePtrCallConv =
611629
callee != nullptr && callee->hasAttr(barePtrAttrName);
612630
} else {
@@ -620,7 +638,7 @@ class CallOpLowering : public CallOpInterfaceLowering<func::CallOp> {
620638
}
621639

622640
private:
623-
const SymbolTable *symbolTable = nullptr;
641+
SymbolTableCollection *symbolTables = nullptr;
624642
};
625643

626644
struct CallIndirectOpLowering
@@ -731,16 +749,17 @@ struct ReturnOpLowering : public ConvertOpToLLVMPattern<func::ReturnOp> {
731749
} // namespace
732750

733751
void mlir::populateFuncToLLVMFuncOpConversionPattern(
734-
const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
735-
patterns.add<FuncOpConversion>(converter);
752+
const LLVMTypeConverter &converter, RewritePatternSet &patterns,
753+
SymbolTableCollection *symbolTables) {
754+
patterns.add<FuncOpConversion>(converter, symbolTables);
736755
}
737756

738757
void mlir::populateFuncToLLVMConversionPatterns(
739758
const LLVMTypeConverter &converter, RewritePatternSet &patterns,
740-
const SymbolTable *symbolTable) {
741-
populateFuncToLLVMFuncOpConversionPattern(converter, patterns);
759+
SymbolTableCollection *symbolTables) {
760+
populateFuncToLLVMFuncOpConversionPattern(converter, patterns, symbolTables);
742761
patterns.add<CallIndirectOpLowering>(converter);
743-
patterns.add<CallOpLowering>(converter, symbolTable);
762+
patterns.add<CallOpLowering>(converter, symbolTables);
744763
patterns.add<ConstantOpLowering>(converter);
745764
patterns.add<ReturnOpLowering>(converter);
746765
}
@@ -780,15 +799,11 @@ struct ConvertFuncToLLVMPass
780799
LLVMTypeConverter typeConverter(&getContext(), options,
781800
&dataLayoutAnalysis);
782801

783-
std::optional<SymbolTable> optSymbolTable = std::nullopt;
784-
const SymbolTable *symbolTable = nullptr;
785-
if (!options.useBarePtrCallConv) {
786-
optSymbolTable.emplace(m);
787-
symbolTable = &optSymbolTable.value();
788-
}
789-
790802
RewritePatternSet patterns(&getContext());
791-
populateFuncToLLVMConversionPatterns(typeConverter, patterns, symbolTable);
803+
SymbolTableCollection symbolTables;
804+
805+
populateFuncToLLVMConversionPatterns(typeConverter, patterns,
806+
&symbolTables);
792807

793808
LLVMConversionTarget target(getContext());
794809
if (failed(applyPartialConversion(m, target, std::move(patterns))))

0 commit comments

Comments
 (0)