Skip to content

[MLIR] Add optional cached symbol tables to LLVM conversion patterns #144032

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Jun 21, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ class DialectRegistry;
class LLVMTypeConverter;
class RewritePatternSet;
class Pass;
class SymbolTableCollection;

#define GEN_PASS_DECL_CONVERTCONTROLFLOWTOLLVMPASS
#include "mlir/Conversion/Passes.h.inc"
Expand All @@ -39,9 +40,9 @@ void populateControlFlowToLLVMConversionPatterns(
/// Populate the cf.assert to LLVM conversion pattern. If `abortOnFailure` is
/// set to false, the program execution continues when a condition is
/// unsatisfied.
void populateAssertToLLVMConversionPattern(const LLVMTypeConverter &converter,
RewritePatternSet &patterns,
bool abortOnFailure = true);
void populateAssertToLLVMConversionPattern(
const LLVMTypeConverter &converter, RewritePatternSet &patterns,
bool abortOnFailure = true, SymbolTableCollection *symbolTables = nullptr);

void registerConvertControlFlowToLLVMInterface(DialectRegistry &registry);

Expand Down
9 changes: 6 additions & 3 deletions mlir/include/mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,20 +27,23 @@ class DialectRegistry;
class LLVMTypeConverter;
class RewritePatternSet;
class SymbolTable;
class SymbolTableCollection;

/// Convert input FunctionOpInterface operation to LLVMFuncOp by using the
/// provided LLVMTypeConverter. Return failure if failed to so.
FailureOr<LLVM::LLVMFuncOp>
convertFuncOpToLLVMFuncOp(FunctionOpInterface funcOp,
ConversionPatternRewriter &rewriter,
const LLVMTypeConverter &converter);
const LLVMTypeConverter &converter,
SymbolTableCollection *symbolTables = nullptr);

/// Collect the default pattern to convert a FuncOp to the LLVM dialect. If
/// `emitCWrappers` is set, the pattern will also produce functions
/// that pass memref descriptors by pointer-to-structure in addition to the
/// default unpacked form.
void populateFuncToLLVMFuncOpConversionPattern(
const LLVMTypeConverter &converter, RewritePatternSet &patterns);
const LLVMTypeConverter &converter, RewritePatternSet &patterns,
SymbolTableCollection *symbolTables = nullptr);

/// Collect the patterns to convert from the Func dialect to LLVM. The
/// conversion patterns capture the LLVMTypeConverter and the LowerToLLVMOptions
Expand All @@ -57,7 +60,7 @@ void populateFuncToLLVMFuncOpConversionPattern(
/// not an error to provide it anyway.
void populateFuncToLLVMConversionPatterns(
const LLVMTypeConverter &converter, RewritePatternSet &patterns,
const SymbolTable *symbolTable = nullptr);
SymbolTableCollection *symbolTables = nullptr);

void registerConvertFuncToLLVMInterface(DialectRegistry &registry);

Expand Down
4 changes: 3 additions & 1 deletion mlir/include/mlir/Conversion/LLVMCommon/PrintCallHelper.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ namespace mlir {

class OpBuilder;
class LLVMTypeConverter;
class SymbolTableCollection;

namespace LLVM {

Expand All @@ -26,7 +27,8 @@ namespace LLVM {
LogicalResult createPrintStrCall(
OpBuilder &builder, Location loc, ModuleOp moduleOp, StringRef symbolName,
StringRef string, const LLVMTypeConverter &typeConverter,
bool addNewline = true, std::optional<StringRef> runtimeFunctionName = {});
bool addNewline = true, std::optional<StringRef> runtimeFunctionName = {},
SymbolTableCollection *symbolTables = nullptr);
} // namespace LLVM

} // namespace mlir
Expand Down
4 changes: 3 additions & 1 deletion mlir/include/mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,16 @@ class DialectRegistry;
class Pass;
class LLVMTypeConverter;
class RewritePatternSet;
class SymbolTableCollection;

#define GEN_PASS_DECL_FINALIZEMEMREFTOLLVMCONVERSIONPASS
#include "mlir/Conversion/Passes.h.inc"

/// Collect a set of patterns to convert memory-related operations from the
/// MemRef dialect to the LLVM dialect.
void populateFinalizeMemRefToLLVMConversionPatterns(
const LLVMTypeConverter &converter, RewritePatternSet &patterns);
const LLVMTypeConverter &converter, RewritePatternSet &patterns,
SymbolTableCollection *symbolTables = nullptr);

void registerConvertMemRefToLLVMInterface(DialectRegistry &registry);

Expand Down
83 changes: 51 additions & 32 deletions mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ class OpBuilder;
class Operation;
class Type;
class ValueRange;
class SymbolTableCollection;

namespace LLVM {
class LLVMFuncOp;
Expand All @@ -33,55 +34,73 @@ class LLVMFuncOp;
/// implemented separately (e.g. as part of a support runtime library or as part
/// of the libc).
/// Failure if an unexpected version of function is found.
FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintI64Fn(OpBuilder &b,
Operation *moduleOp);
FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintU64Fn(OpBuilder &b,
Operation *moduleOp);
FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintF16Fn(OpBuilder &b,
Operation *moduleOp);
FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintBF16Fn(OpBuilder &b,
Operation *moduleOp);
FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintF32Fn(OpBuilder &b,
Operation *moduleOp);
FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintF64Fn(OpBuilder &b,
Operation *moduleOp);
FailureOr<LLVM::LLVMFuncOp>
lookupOrCreatePrintI64Fn(OpBuilder &b, Operation *moduleOp,
SymbolTableCollection *symbolTables = nullptr);
FailureOr<LLVM::LLVMFuncOp>
lookupOrCreatePrintU64Fn(OpBuilder &b, Operation *moduleOp,
SymbolTableCollection *symbolTables = nullptr);
FailureOr<LLVM::LLVMFuncOp>
lookupOrCreatePrintF16Fn(OpBuilder &b, Operation *moduleOp,
SymbolTableCollection *symbolTables = nullptr);
FailureOr<LLVM::LLVMFuncOp>
lookupOrCreatePrintBF16Fn(OpBuilder &b, Operation *moduleOp,
SymbolTableCollection *symbolTables = nullptr);
FailureOr<LLVM::LLVMFuncOp>
lookupOrCreatePrintF32Fn(OpBuilder &b, Operation *moduleOp,
SymbolTableCollection *symbolTables = nullptr);
FailureOr<LLVM::LLVMFuncOp>
lookupOrCreatePrintF64Fn(OpBuilder &b, Operation *moduleOp,
SymbolTableCollection *symbolTables = nullptr);
Comment on lines +37 to +54
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not for you to fix, but this having a separate function call for each function name instead of using an enum is what made this change so big :(

/// Declares a function to print a C-string.
/// If a custom runtime function is defined via `runtimeFunctionName`, it must
/// have the signature void(char const*). The default function is `printString`.
FailureOr<LLVM::LLVMFuncOp>
lookupOrCreatePrintStringFn(OpBuilder &b, Operation *moduleOp,
std::optional<StringRef> runtimeFunctionName = {});
FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintOpenFn(OpBuilder &b,
Operation *moduleOp);
FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintCloseFn(OpBuilder &b,
Operation *moduleOp);
FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintCommaFn(OpBuilder &b,
Operation *moduleOp);
FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintNewlineFn(OpBuilder &b,
Operation *moduleOp);
std::optional<StringRef> runtimeFunctionName = {},
SymbolTableCollection *symbolTables = nullptr);
FailureOr<LLVM::LLVMFuncOp>
lookupOrCreatePrintOpenFn(OpBuilder &b, Operation *moduleOp,
SymbolTableCollection *symbolTables = nullptr);
FailureOr<LLVM::LLVMFuncOp>
lookupOrCreatePrintCloseFn(OpBuilder &b, Operation *moduleOp,
SymbolTableCollection *symbolTables = nullptr);
FailureOr<LLVM::LLVMFuncOp>
lookupOrCreatePrintCommaFn(OpBuilder &b, Operation *moduleOp,
SymbolTableCollection *symbolTables = nullptr);
FailureOr<LLVM::LLVMFuncOp>
lookupOrCreatePrintNewlineFn(OpBuilder &b, Operation *moduleOp,
SymbolTableCollection *symbolTables = nullptr);
FailureOr<LLVM::LLVMFuncOp>
lookupOrCreateMallocFn(OpBuilder &b, Operation *moduleOp, Type indexType,
SymbolTableCollection *symbolTables = nullptr);
FailureOr<LLVM::LLVMFuncOp>
lookupOrCreateMallocFn(OpBuilder &b, Operation *moduleOp, Type indexType);
lookupOrCreateAlignedAllocFn(OpBuilder &b, Operation *moduleOp, Type indexType,
SymbolTableCollection *symbolTables = nullptr);
FailureOr<LLVM::LLVMFuncOp>
lookupOrCreateAlignedAllocFn(OpBuilder &b, Operation *moduleOp, Type indexType);
FailureOr<LLVM::LLVMFuncOp> lookupOrCreateFreeFn(OpBuilder &b,
Operation *moduleOp);
lookupOrCreateFreeFn(OpBuilder &b, Operation *moduleOp,
SymbolTableCollection *symbolTables = nullptr);
FailureOr<LLVM::LLVMFuncOp>
lookupOrCreateGenericAllocFn(OpBuilder &b, Operation *moduleOp, Type indexType);
lookupOrCreateGenericAllocFn(OpBuilder &b, Operation *moduleOp, Type indexType,
SymbolTableCollection *symbolTables = nullptr);
FailureOr<LLVM::LLVMFuncOp> lookupOrCreateGenericAlignedAllocFn(
OpBuilder &b, Operation *moduleOp, Type indexType,
SymbolTableCollection *symbolTables = nullptr);
FailureOr<LLVM::LLVMFuncOp>
lookupOrCreateGenericAlignedAllocFn(OpBuilder &b, Operation *moduleOp,
Type indexType);
FailureOr<LLVM::LLVMFuncOp> lookupOrCreateGenericFreeFn(OpBuilder &b,
Operation *moduleOp);
lookupOrCreateGenericFreeFn(OpBuilder &b, Operation *moduleOp,
SymbolTableCollection *symbolTables = nullptr);
FailureOr<LLVM::LLVMFuncOp>
lookupOrCreateMemRefCopyFn(OpBuilder &b, Operation *moduleOp, Type indexType,
Type unrankedDescriptorType);
Type unrankedDescriptorType,
SymbolTableCollection *symbolTables = nullptr);

/// Create a FuncOp with signature `resultType`(`paramTypes`)` and name `name`.
/// Return a failure if the FuncOp found has unexpected signature.
FailureOr<LLVM::LLVMFuncOp>
lookupOrCreateFn(OpBuilder &b, Operation *moduleOp, StringRef name,
ArrayRef<Type> paramTypes = {}, Type resultType = {},
bool isVarArg = false, bool isReserved = false);
bool isVarArg = false, bool isReserved = false,
SymbolTableCollection *symbolTables = nullptr);

} // namespace LLVM
} // namespace mlir
Expand Down
13 changes: 8 additions & 5 deletions mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,10 @@ namespace {
/// lowering.
struct AssertOpLowering : public ConvertOpToLLVMPattern<cf::AssertOp> {
explicit AssertOpLowering(const LLVMTypeConverter &typeConverter,
bool abortOnFailedAssert = true)
bool abortOnFailedAssert = true,
SymbolTableCollection *symbolTables = nullptr)
: ConvertOpToLLVMPattern<cf::AssertOp>(typeConverter, /*benefit=*/1),
abortOnFailedAssert(abortOnFailedAssert) {}
abortOnFailedAssert(abortOnFailedAssert), symbolTables(symbolTables) {}

LogicalResult
matchAndRewrite(cf::AssertOp op, OpAdaptor adaptor,
Expand All @@ -64,7 +65,7 @@ struct AssertOpLowering : public ConvertOpToLLVMPattern<cf::AssertOp> {
auto createResult = LLVM::createPrintStrCall(
rewriter, loc, module, "assert_msg", op.getMsg(), *getTypeConverter(),
/*addNewLine=*/false,
/*runtimeFunctionName=*/"puts");
/*runtimeFunctionName=*/"puts", symbolTables);
if (createResult.failed())
return failure();

Expand Down Expand Up @@ -96,6 +97,8 @@ struct AssertOpLowering : public ConvertOpToLLVMPattern<cf::AssertOp> {
/// If set to `false`, messages are printed but program execution continues.
/// This is useful for testing asserts.
bool abortOnFailedAssert = true;

SymbolTableCollection *symbolTables = nullptr;
};

/// Helper function for converting branch ops. This function converts the
Expand Down Expand Up @@ -232,8 +235,8 @@ void mlir::cf::populateControlFlowToLLVMConversionPatterns(

void mlir::cf::populateAssertToLLVMConversionPattern(
const LLVMTypeConverter &converter, RewritePatternSet &patterns,
bool abortOnFailure) {
patterns.add<AssertOpLowering>(converter, abortOnFailure);
bool abortOnFailure, SymbolTableCollection *symbolTables) {
patterns.add<AssertOpLowering>(converter, abortOnFailure, symbolTables);
}

//===----------------------------------------------------------------------===//
Expand Down
71 changes: 43 additions & 28 deletions mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -299,10 +299,9 @@ static void restoreByValRefArgumentType(
}
}

FailureOr<LLVM::LLVMFuncOp>
mlir::convertFuncOpToLLVMFuncOp(FunctionOpInterface funcOp,
ConversionPatternRewriter &rewriter,
const LLVMTypeConverter &converter) {
FailureOr<LLVM::LLVMFuncOp> mlir::convertFuncOpToLLVMFuncOp(
FunctionOpInterface funcOp, ConversionPatternRewriter &rewriter,
const LLVMTypeConverter &converter, SymbolTableCollection *symbolTables) {
// Check the funcOp has `FunctionType`.
auto funcTy = dyn_cast<FunctionType>(funcOp.getFunctionType());
if (!funcTy)
Expand Down Expand Up @@ -361,10 +360,25 @@ mlir::convertFuncOpToLLVMFuncOp(FunctionOpInterface funcOp,

SmallVector<NamedAttribute, 4> attributes;
filterFuncAttributes(funcOp, attributes);

Operation *symbolTableOp = funcOp->getParentWithTrait<OpTrait::SymbolTable>();

if (symbolTables && symbolTableOp) {
SymbolTable &symbolTable = symbolTables->getSymbolTable(symbolTableOp);
symbolTable.remove(funcOp);
}

auto newFuncOp = rewriter.create<LLVM::LLVMFuncOp>(
funcOp.getLoc(), funcOp.getName(), llvmType, linkage,
/*dsoLocal=*/false, /*cconv=*/LLVM::CConv::C, /*comdat=*/nullptr,
attributes);

if (symbolTables && symbolTableOp) {
auto ip = rewriter.getInsertionPoint();
SymbolTable &symbolTable = symbolTables->getSymbolTable(symbolTableOp);
symbolTable.insert(newFuncOp, ip);
}

cast<FunctionOpInterface>(newFuncOp.getOperation())
.setVisibility(funcOp.getVisibility());

Expand Down Expand Up @@ -473,16 +487,20 @@ namespace {
/// FuncOp legalization pattern that converts MemRef arguments to pointers to
/// MemRef descriptors (LLVM struct data types) containing all the MemRef type
/// information.
struct FuncOpConversion : public ConvertOpToLLVMPattern<func::FuncOp> {
FuncOpConversion(const LLVMTypeConverter &converter)
: ConvertOpToLLVMPattern(converter) {}
class FuncOpConversion : public ConvertOpToLLVMPattern<func::FuncOp> {
SymbolTableCollection *symbolTables = nullptr;

public:
explicit FuncOpConversion(const LLVMTypeConverter &converter,
SymbolTableCollection *symbolTables = nullptr)
: ConvertOpToLLVMPattern(converter), symbolTables(symbolTables) {}

LogicalResult
matchAndRewrite(func::FuncOp funcOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
FailureOr<LLVM::LLVMFuncOp> newFuncOp = mlir::convertFuncOpToLLVMFuncOp(
cast<FunctionOpInterface>(funcOp.getOperation()), rewriter,
*getTypeConverter());
*getTypeConverter(), symbolTables);
if (failed(newFuncOp))
return rewriter.notifyMatchFailure(funcOp, "Could not convert funcop");

Expand Down Expand Up @@ -591,22 +609,22 @@ struct CallOpInterfaceLowering : public ConvertOpToLLVMPattern<CallOpType> {

class CallOpLowering : public CallOpInterfaceLowering<func::CallOp> {
public:
CallOpLowering(const LLVMTypeConverter &typeConverter,
// Can be nullptr.
const SymbolTable *symbolTable, PatternBenefit benefit = 1)
explicit CallOpLowering(const LLVMTypeConverter &typeConverter,
SymbolTableCollection *symbolTables = nullptr,
PatternBenefit benefit = 1)
: CallOpInterfaceLowering<func::CallOp>(typeConverter, benefit),
symbolTable(symbolTable) {}
symbolTables(symbolTables) {}

LogicalResult
matchAndRewrite(func::CallOp callOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
bool useBarePtrCallConv = false;
if (getTypeConverter()->getOptions().useBarePtrCallConv) {
useBarePtrCallConv = true;
} else if (symbolTable != nullptr) {
} else if (symbolTables != nullptr) {
// Fast lookup.
Operation *callee =
symbolTable->lookup(callOp.getCalleeAttr().getValue());
symbolTables->lookupNearestSymbolFrom(callOp, callOp.getCalleeAttr());
useBarePtrCallConv =
callee != nullptr && callee->hasAttr(barePtrAttrName);
} else {
Expand All @@ -620,7 +638,7 @@ class CallOpLowering : public CallOpInterfaceLowering<func::CallOp> {
}

private:
const SymbolTable *symbolTable = nullptr;
SymbolTableCollection *symbolTables = nullptr;
};

struct CallIndirectOpLowering
Expand Down Expand Up @@ -731,16 +749,17 @@ struct ReturnOpLowering : public ConvertOpToLLVMPattern<func::ReturnOp> {
} // namespace

void mlir::populateFuncToLLVMFuncOpConversionPattern(
const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
patterns.add<FuncOpConversion>(converter);
const LLVMTypeConverter &converter, RewritePatternSet &patterns,
SymbolTableCollection *symbolTables) {
patterns.add<FuncOpConversion>(converter, symbolTables);
}

void mlir::populateFuncToLLVMConversionPatterns(
const LLVMTypeConverter &converter, RewritePatternSet &patterns,
const SymbolTable *symbolTable) {
populateFuncToLLVMFuncOpConversionPattern(converter, patterns);
SymbolTableCollection *symbolTables) {
populateFuncToLLVMFuncOpConversionPattern(converter, patterns, symbolTables);
patterns.add<CallIndirectOpLowering>(converter);
patterns.add<CallOpLowering>(converter, symbolTable);
patterns.add<CallOpLowering>(converter, symbolTables);
patterns.add<ConstantOpLowering>(converter);
patterns.add<ReturnOpLowering>(converter);
}
Expand Down Expand Up @@ -780,15 +799,11 @@ struct ConvertFuncToLLVMPass
LLVMTypeConverter typeConverter(&getContext(), options,
&dataLayoutAnalysis);

std::optional<SymbolTable> optSymbolTable = std::nullopt;
const SymbolTable *symbolTable = nullptr;
if (!options.useBarePtrCallConv) {
optSymbolTable.emplace(m);
symbolTable = &optSymbolTable.value();
}

RewritePatternSet patterns(&getContext());
populateFuncToLLVMConversionPatterns(typeConverter, patterns, symbolTable);
SymbolTableCollection symbolTables;

populateFuncToLLVMConversionPatterns(typeConverter, patterns,
&symbolTables);

LLVMConversionTarget target(getContext());
if (failed(applyPartialConversion(m, target, std::move(patterns))))
Expand Down
Loading