Skip to content

Commit 5f1c7e2

Browse files
committed
[mlir] Use SymbolTableCollection to lookup referenced symbol in AddressOfOp
Depends On D131285 Reviewed By: Mogball Differential Revision: https://reviews.llvm.org/D131291
1 parent 3fa291f commit 5f1c7e2

File tree

5 files changed

+18
-9
lines changed

5 files changed

+18
-9
lines changed

mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1031,10 +1031,10 @@ def LLVM_AddressOfOp : LLVM_Op<"mlir.addressof",
10311031
let extraClassDeclaration = [{
10321032
/// Return the llvm.mlir.global operation that defined the value referenced
10331033
/// here.
1034-
GlobalOp getGlobal();
1034+
GlobalOp getGlobal(SymbolTableCollection &symbolTable);
10351035

10361036
/// Return the llvm.func operation that is referenced here.
1037-
LLVMFuncOp getFunction();
1037+
LLVMFuncOp getFunction(SymbolTableCollection &symbolTable);
10381038
}];
10391039

10401040
let assemblyFormat = "$global_name attr-dict `:` type($res)";

mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
#include "mlir/IR/Operation.h"
1818
#include "mlir/IR/Value.h"
19+
#include "mlir/IR/SymbolTable.h"
1920
#include "mlir/Target/LLVMIR/Export.h"
2021
#include "mlir/Target/LLVMIR/LLVMTranslationInterface.h"
2122
#include "mlir/Target/LLVMIR/TypeToLLVM.h"
@@ -264,6 +265,8 @@ class ModuleTranslation {
264265
ModuleTranslation &moduleTranslation;
265266
};
266267

268+
SymbolTableCollection& symbolTable() { return symbolTableCollection; }
269+
267270
private:
268271
ModuleTranslation(Operation *module,
269272
std::unique_ptr<llvm::Module> llvmModule);
@@ -333,6 +336,9 @@ class ModuleTranslation {
333336
/// Stack of user-specified state elements, useful when translating operations
334337
/// with regions.
335338
SmallVector<std::unique_ptr<StackFrame>> stack;
339+
340+
/// A cache for the symbol tables constructed during symbols lookup.
341+
SymbolTableCollection symbolTableCollection;
336342
};
337343

338344
namespace detail {

mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1737,14 +1737,14 @@ static Operation *parentLLVMModule(Operation *op) {
17371737
return module;
17381738
}
17391739

1740-
GlobalOp AddressOfOp::getGlobal() {
1740+
GlobalOp AddressOfOp::getGlobal(SymbolTableCollection &symbolTable) {
17411741
return dyn_cast_or_null<GlobalOp>(
1742-
SymbolTable::lookupSymbolIn(parentLLVMModule(*this), getGlobalName()));
1742+
symbolTable.lookupSymbolIn(parentLLVMModule(*this), getGlobalNameAttr()));
17431743
}
17441744

1745-
LLVMFuncOp AddressOfOp::getFunction() {
1745+
LLVMFuncOp AddressOfOp::getFunction(SymbolTableCollection &symbolTable) {
17461746
return dyn_cast_or_null<LLVMFuncOp>(
1747-
SymbolTable::lookupSymbolIn(parentLLVMModule(*this), getGlobalName()));
1747+
symbolTable.lookupSymbolIn(parentLLVMModule(*this), getGlobalNameAttr()));
17481748
}
17491749

17501750
LogicalResult

mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -466,8 +466,10 @@ convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder,
466466
// operation and store it in the MLIR-to-LLVM value mapping. This does not
467467
// emit any LLVM instruction.
468468
if (auto addressOfOp = dyn_cast<LLVM::AddressOfOp>(opInst)) {
469-
LLVM::GlobalOp global = addressOfOp.getGlobal();
470-
LLVM::LLVMFuncOp function = addressOfOp.getFunction();
469+
LLVM::GlobalOp global =
470+
addressOfOp.getGlobal(moduleTranslation.symbolTable());
471+
LLVM::LLVMFuncOp function =
472+
addressOfOp.getFunction(moduleTranslation.symbolTable());
471473

472474
// The verifier should not have allowed this.
473475
assert((global || function) &&

mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1285,7 +1285,8 @@ convertOmpThreadprivate(Operation &opInst, llvm::IRBuilderBase &builder,
12851285
return opInst.emitError("Addressing symbol not found");
12861286
LLVM::AddressOfOp addressOfOp = dyn_cast<LLVM::AddressOfOp>(symOp);
12871287

1288-
LLVM::GlobalOp global = addressOfOp.getGlobal();
1288+
LLVM::GlobalOp global =
1289+
addressOfOp.getGlobal(moduleTranslation.symbolTable());
12891290
llvm::GlobalValue *globalValue = moduleTranslation.lookupGlobal(global);
12901291
llvm::Value *data =
12911292
builder.CreateBitCast(globalValue, builder.getInt8PtrTy());

0 commit comments

Comments
 (0)