Skip to content

Commit d47cd50

Browse files
[MLIR][LLVM] Fix blockaddress mapping to LLVM blocks (#139814)
After each function is translated, both value and block maps are erased, which makes the current mapping of blockaddresses to llvm blocks broken - the patching happens only after *all* functions are translated. Simplify the overall mapping, update comments, variable names and fix the bug. --------- Co-authored-by: Christian Ulmann <[email protected]>
1 parent 539265b commit d47cd50

File tree

4 files changed

+48
-30
lines changed

4 files changed

+48
-30
lines changed

mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h

Lines changed: 13 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -142,21 +142,20 @@ class ModuleTranslation {
142142
auto result = unresolvedBlockAddressMapping.try_emplace(op, cst);
143143
(void)result;
144144
assert(result.second &&
145-
"attempting to map a blockaddress that is already mapped");
145+
"attempting to map a blockaddress operation that is already mapped");
146146
}
147147

148-
/// Maps a blockaddress operation to its corresponding placeholder LLVM
149-
/// value.
150-
void mapBlockTag(BlockAddressAttr attr, BlockTagOp blockTag) {
151-
// Attempts to map already mapped block labels which is fine if the given
152-
// labels are verified to be unique.
153-
blockTagMapping[attr] = blockTag;
148+
/// Maps a BlockAddressAttr to its corresponding LLVM basic block.
149+
void mapBlockAddress(BlockAddressAttr attr, llvm::BasicBlock *block) {
150+
auto result = blockAddressToLLVMMapping.try_emplace(attr, block);
151+
(void)result;
152+
assert(result.second &&
153+
"attempting to map a blockaddress attribute that is already mapped");
154154
}
155155

156-
/// Finds an MLIR block that corresponds to the given MLIR call
157-
/// operation.
158-
BlockTagOp lookupBlockTag(BlockAddressAttr attr) const {
159-
return blockTagMapping.lookup(attr);
156+
/// Finds the LLVM basic block that corresponds to the given BlockAddressAttr.
157+
llvm::BasicBlock *lookupBlockAddress(BlockAddressAttr attr) const {
158+
return blockAddressToLLVMMapping.lookup(attr);
160159
}
161160

162161
/// Removes the mapping for blocks contained in the region and values defined
@@ -463,10 +462,9 @@ class ModuleTranslation {
463462
/// mapping is used to replace the placeholders with the LLVM block addresses.
464463
DenseMap<BlockAddressOp, llvm::Value *> unresolvedBlockAddressMapping;
465464

466-
/// Mapping from a BlockAddressAttr attribute to a matching BlockTagOp. This
467-
/// is used to cache BlockTagOp locations instead of walking a LLVMFuncOp in
468-
/// search for those.
469-
DenseMap<BlockAddressAttr, BlockTagOp> blockTagMapping;
465+
/// Mapping from a BlockAddressAttr attribute to it's matching LLVM basic
466+
/// block.
467+
DenseMap<BlockAddressAttr, llvm::BasicBlock *> blockAddressToLLVMMapping;
470468

471469
/// Stack of user-specified state elements, useful when translating operations
472470
/// with regions.

mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -690,19 +690,13 @@ convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder,
690690
// Emit blockaddress. We first need to find the LLVM block referenced by this
691691
// operation and then create a LLVM block address for it.
692692
if (auto blockAddressOp = dyn_cast<LLVM::BlockAddressOp>(opInst)) {
693-
// getBlockTagOp() walks a function to search for block labels. Check
694-
// whether it's in cache first.
695693
BlockAddressAttr blockAddressAttr = blockAddressOp.getBlockAddr();
696-
BlockTagOp blockTagOp = moduleTranslation.lookupBlockTag(blockAddressAttr);
697-
if (!blockTagOp) {
698-
blockTagOp = blockAddressOp.getBlockTagOp();
699-
moduleTranslation.mapBlockTag(blockAddressAttr, blockTagOp);
700-
}
694+
llvm::BasicBlock *llvmBlock =
695+
moduleTranslation.lookupBlockAddress(blockAddressAttr);
701696

702697
llvm::Value *llvmValue = nullptr;
703698
StringRef fnName = blockAddressAttr.getFunction().getValue();
704-
if (llvm::BasicBlock *llvmBlock =
705-
moduleTranslation.lookupBlock(blockTagOp->getBlock())) {
699+
if (llvmBlock) {
706700
llvm::Function *llvmFn = moduleTranslation.lookupFunction(fnName);
707701
llvmValue = llvm::BlockAddress::get(llvmFn, llvmBlock);
708702
} else {
@@ -736,7 +730,8 @@ convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder,
736730
FlatSymbolRefAttr::get(&moduleTranslation.getContext(),
737731
funcOp.getName()),
738732
blockTagOp.getTag());
739-
moduleTranslation.mapBlockTag(blockAddressAttr, blockTagOp);
733+
moduleTranslation.mapBlockAddress(blockAddressAttr,
734+
builder.GetInsertBlock());
740735
return success();
741736
}
742737

mlir/lib/Target/LLVMIR/ModuleTranslation.cpp

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1843,17 +1843,13 @@ LogicalResult ModuleTranslation::convertComdats() {
18431843
LogicalResult ModuleTranslation::convertUnresolvedBlockAddress() {
18441844
for (auto &[blockAddressOp, llvmCst] : unresolvedBlockAddressMapping) {
18451845
BlockAddressAttr blockAddressAttr = blockAddressOp.getBlockAddr();
1846-
BlockTagOp blockTagOp = lookupBlockTag(blockAddressAttr);
1847-
assert(blockTagOp && "expected all block tags to be already seen");
1848-
1849-
llvm::BasicBlock *llvmBlock = lookupBlock(blockTagOp->getBlock());
1846+
llvm::BasicBlock *llvmBlock = lookupBlockAddress(blockAddressAttr);
18501847
assert(llvmBlock && "expected LLVM blocks to be already translated");
18511848

18521849
// Update mapping with new block address constant.
18531850
auto *llvmBlockAddr = llvm::BlockAddress::get(
18541851
lookupFunction(blockAddressAttr.getFunction().getValue()), llvmBlock);
18551852
llvmCst->replaceAllUsesWith(llvmBlockAddr);
1856-
mapValue(blockAddressOp.getResult(), llvmBlockAddr);
18571853
assert(llvmCst->use_empty() && "expected all uses to be replaced");
18581854
cast<llvm::GlobalVariable>(llvmCst)->eraseFromParent();
18591855
}

mlir/test/Target/LLVMIR/blockaddress.mlir

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,3 +34,32 @@ llvm.func @blockaddr0() -> !llvm.ptr {
3434
// CHECK: [[RET]]:
3535
// CHECK: ret ptr blockaddress(@blockaddr0, %1)
3636
// CHECK: }
37+
38+
// -----
39+
40+
llvm.mlir.global private @h() {addr_space = 0 : i32, dso_local} : !llvm.ptr {
41+
%0 = llvm.blockaddress <function = @h3, tag = <id = 0>> : !llvm.ptr
42+
llvm.return %0 : !llvm.ptr
43+
}
44+
45+
// CHECK: @h = private global ptr blockaddress(@h3, %[[BB_ADDR:.*]])
46+
47+
// CHECK: define void @h3() {
48+
// CHECK: br label %[[BB_ADDR]]
49+
50+
// CHECK: [[BB_ADDR]]:
51+
// CHECK: ret void
52+
// CHECK: }
53+
54+
// CHECK: define void @h0()
55+
56+
llvm.func @h3() {
57+
llvm.br ^bb1
58+
^bb1:
59+
llvm.blocktag <id = 0>
60+
llvm.return
61+
}
62+
63+
llvm.func @h0() {
64+
llvm.return
65+
}

0 commit comments

Comments
 (0)