Skip to content

[MLIR][LLVM] Support for indirectbr #135092

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Apr 14, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 61 additions & 2 deletions mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1690,8 +1690,7 @@ def LLVM_BlockTagOp : LLVM_Op<"blocktag"> {
the target address.

A given function should have at most one `llvm.blocktag` operation with a
given `tag`. This operation cannot be used as a terminator but can be
placed everywhere else in a block.
given `tag`. This operation cannot be used as a terminator.

Example:

Expand All @@ -1711,6 +1710,66 @@ def LLVM_BlockTagOp : LLVM_Op<"blocktag"> {
let hasVerifier = 0;
}

//===----------------------------------------------------------------------===//
// IndirectBrOp
//===----------------------------------------------------------------------===//

def LLVM_IndirectBrOp : LLVM_TerminatorOp<"indirectbr",
[SameVariadicOperandSize, DeclareOpInterfaceMethods<BranchOpInterface>,
Pure]> {
let description = [{
Transfer control flow to address in `$addr`. A list of possible target
blocks in `$successors` can be provided and maybe used as a hint in LLVM:

```mlir
...
llvm.func @g(...
%dest = llvm.blockaddress <function = @g, tag = <id = 0>> : !llvm.ptr
llvm.indirectbr %dest : !llvm.ptr, [
^head
]
^head:
llvm.blocktag <id = 0>
llvm.return %arg0 : i32
...
```

It also supports a list of operands that can be passed to a target block:

```mlir
llvm.indirectbr %dest : !llvm.ptr, [
^head(%arg0 : i32),
^tail(%arg1, %arg0 : i32, i32)
]
^head(%r0 : i32):
llvm.return %r0 : i32
^tail(%r1 : i32, %r2 : i32):
...
```
}];
let arguments = (ins LLVM_AnyPointer:$addr,
VariadicOfVariadic<AnyType, "indbr_operand_segments">:$succOperands,
DenseI32ArrayAttr:$indbr_operand_segments
);
let successors = (successor VariadicSuccessor<AnySuccessor>:$successors);
let assemblyFormat = [{
$addr `:` type($addr) `,`
custom<IndirectBrOpSucessors>(ref(type($addr)),
$successors,
$succOperands,
type($succOperands))
attr-dict
}];

let skipDefaultBuilders = 1;
let builders = [
OpBuilder<(ins "Value":$addr,
CArg<"ArrayRef<ValueRange>", "{}">:$succOperands,
CArg<"BlockRange", "{}">:$successors
)>
];
}

def LLVM_ComdatSelectorOp : LLVM_Op<"comdat_selector", [Symbol]> {
let arguments = (ins
SymbolNameAttr:$sym_name,
Expand Down
103 changes: 86 additions & 17 deletions mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2240,24 +2240,21 @@ static LogicalResult verifyComdat(Operation *op,

static LogicalResult verifyBlockTags(LLVMFuncOp funcOp) {
llvm::DenseSet<BlockTagAttr> blockTags;
BlockTagOp badBlockTagOp;
if (funcOp
.walk([&](BlockTagOp blockTagOp) {
if (blockTags.contains(blockTagOp.getTag())) {
badBlockTagOp = blockTagOp;
return WalkResult::interrupt();
}
blockTags.insert(blockTagOp.getTag());
return WalkResult::advance();
})
.wasInterrupted()) {
badBlockTagOp.emitError()
<< "duplicate block tag '" << badBlockTagOp.getTag().getId()
<< "' in the same function: ";
return failure();
}
// Note that presence of `BlockTagOp`s currently can't prevent an unrecheable
// block to be removed by canonicalizer's region simplify pass, which needs to
// be dialect aware to allow extra constraints to be described.
WalkResult res = funcOp.walk([&](BlockTagOp blockTagOp) {
if (blockTags.contains(blockTagOp.getTag())) {
blockTagOp.emitError()
<< "duplicate block tag '" << blockTagOp.getTag().getId()
<< "' in the same function: ";
return WalkResult::interrupt();
}
blockTags.insert(blockTagOp.getTag());
return WalkResult::advance();
});

return success();
return failure(res.wasInterrupted());
}

/// Parse common attributes that might show up in the same order in both
Expand Down Expand Up @@ -3818,6 +3815,78 @@ LogicalResult BlockAddressOp::verify() {
/// attribute.
OpFoldResult BlockAddressOp::fold(FoldAdaptor) { return getBlockAddr(); }

//===----------------------------------------------------------------------===//
// LLVM::IndirectBrOp
//===----------------------------------------------------------------------===//

SuccessorOperands IndirectBrOp::getSuccessorOperands(unsigned index) {
assert(index < getNumSuccessors() && "invalid successor index");
return SuccessorOperands(getSuccOperandsMutable()[index]);
}

void IndirectBrOp::build(OpBuilder &odsBuilder, OperationState &odsState,
Value addr, ArrayRef<ValueRange> succOperands,
BlockRange successors) {
odsState.addOperands(addr);
for (ValueRange range : succOperands)
odsState.addOperands(range);
SmallVector<int32_t> rangeSegments;
for (ValueRange range : succOperands)
rangeSegments.push_back(range.size());
odsState.getOrAddProperties<Properties>().indbr_operand_segments =
odsBuilder.getDenseI32ArrayAttr(rangeSegments);
odsState.addSuccessors(successors);
}

static ParseResult parseIndirectBrOpSucessors(
OpAsmParser &parser, Type &flagType,
SmallVectorImpl<Block *> &succOperandBlocks,
SmallVectorImpl<SmallVector<OpAsmParser::UnresolvedOperand>> &succOperands,
SmallVectorImpl<SmallVector<Type>> &succOperandsTypes) {
if (failed(parser.parseCommaSeparatedList(
OpAsmParser::Delimiter::Square,
[&]() {
Block *destination = nullptr;
SmallVector<OpAsmParser::UnresolvedOperand> operands;
SmallVector<Type> operandTypes;

if (parser.parseSuccessor(destination).failed())
return failure();

if (succeeded(parser.parseOptionalLParen())) {
if (failed(parser.parseOperandList(
operands, OpAsmParser::Delimiter::None)) ||
failed(parser.parseColonTypeList(operandTypes)) ||
failed(parser.parseRParen()))
return failure();
}
succOperandBlocks.push_back(destination);
succOperands.emplace_back(operands);
succOperandsTypes.emplace_back(operandTypes);
return success();
},
"successor blocks")))
return failure();
return success();
}

static void
printIndirectBrOpSucessors(OpAsmPrinter &p, IndirectBrOp op, Type flagType,
SuccessorRange succs, OperandRangeRange succOperands,
const TypeRangeRange &succOperandsTypes) {
p << "[";
llvm::interleave(
llvm::zip(succs, succOperands),
[&](auto i) {
p.printNewline();
p.printSuccessorAndUseList(std::get<0>(i), std::get<1>(i));
},
[&] { p << ','; });
if (!succOperands.empty())
p.printNewline();
p << "]";
}

//===----------------------------------------------------------------------===//
// AssumeOp (intrinsic)
//===----------------------------------------------------------------------===//
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -505,6 +505,15 @@ convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder,
moduleTranslation.mapBranch(&opInst, switchInst);
return success();
}
if (auto indBrOp = dyn_cast<LLVM::IndirectBrOp>(opInst)) {
llvm::IndirectBrInst *indBr = builder.CreateIndirectBr(
moduleTranslation.lookupValue(indBrOp.getAddr()),
indBrOp->getNumSuccessors());
for (auto *succ : indBrOp.getSuccessors())
indBr->addDestination(moduleTranslation.lookupBlock(succ));
moduleTranslation.mapBranch(&opInst, indBr);
return success();
}

// Emit addressof. We need to look up the global value referenced by the
// operation and store it in the MLIR-to-LLVM value mapping. This does not
Expand Down
33 changes: 31 additions & 2 deletions mlir/lib/Target/LLVMIR/ModuleImport.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1988,6 +1988,35 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) {
return success();
}

if (inst->getOpcode() == llvm::Instruction::IndirectBr) {
auto *indBrInst = cast<llvm::IndirectBrInst>(inst);

FailureOr<Value> basePtr = convertValue(indBrInst->getAddress());
if (failed(basePtr))
return failure();

SmallVector<Block *> succBlocks;
// `succBlockArgs` is storage for the block arguments ranges used in
// `succBlockArgsRange`, so the later references live data.
SmallVector<SmallVector<Value>> succBlockArgs;
SmallVector<ValueRange> succBlockArgsRange;
for (auto i : llvm::seq<unsigned>(0, indBrInst->getNumSuccessors())) {
llvm::BasicBlock *succ = indBrInst->getSuccessor(i);
SmallVector<Value> blockArgs;
if (failed(convertBranchArgs(indBrInst, succ, blockArgs)))
return failure();
succBlocks.push_back(lookupBlock(succ));
succBlockArgs.push_back(blockArgs);
succBlockArgsRange.push_back(succBlockArgs.back());
}
Location loc = translateLoc(inst->getDebugLoc());
auto indBrOp = builder.create<LLVM::IndirectBrOp>(
loc, *basePtr, succBlockArgsRange, succBlocks);

mapNoResultOp(inst, indBrOp);
return success();
}

// Convert all instructions that have an mlirBuilder.
if (succeeded(convertInstructionImpl(builder, inst, *this, iface)))
return success();
Expand All @@ -1998,8 +2027,8 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) {
LogicalResult ModuleImport::processInstruction(llvm::Instruction *inst) {
// FIXME: Support uses of SubtargetData.
// FIXME: Add support for call / operand attributes.
// FIXME: Add support for the indirectbr, cleanupret, catchret, catchswitch,
// callbr, vaarg, catchpad, cleanuppad instructions.
// FIXME: Add support for the cleanupret, catchret, catchswitch, callbr,
// vaarg, catchpad, cleanuppad instructions.

// Convert LLVM intrinsics calls to MLIR intrinsics.
if (auto *intrinsic = dyn_cast<llvm::IntrinsicInst>(inst))
Expand Down
8 changes: 8 additions & 0 deletions mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -805,6 +805,14 @@ static Value getPHISourceValue(Block *current, Block *pred,
return switchOp.getCaseOperands(i.index())[index];
}

if (auto indBrOp = dyn_cast<LLVM::IndirectBrOp>(terminator)) {
// For indirect branches we take operands for each successor.
for (const auto &i : llvm::enumerate(indBrOp->getSuccessors())) {
if (indBrOp->getSuccessor(i.index()) == current)
return indBrOp.getSuccessorOperands(i.index())[index];
}
}

if (auto invokeOp = dyn_cast<LLVM::InvokeOp>(terminator)) {
return invokeOp.getNormalDest() == current
? invokeOp.getNormalDestOperands()[index]
Expand Down
2 changes: 1 addition & 1 deletion mlir/test/Dialect/LLVMIR/blockaddress-canonicalize.mlir
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: mlir-opt %s -pass-pipeline='builtin.module(llvm.func(canonicalize{region-simplify=aggressive}))' -split-input-file | FileCheck %s
// RUN: mlir-opt %s -pass-pipeline='builtin.module(llvm.func(canonicalize{region-simplify=aggressive}))' -verify-diagnostics -split-input-file | FileCheck %s

llvm.mlir.global private @x() {addr_space = 0 : i32, dso_local} : !llvm.ptr {
%0 = llvm.blockaddress <function = @ba, tag = <id = 2>> : !llvm.ptr
Expand Down
83 changes: 83 additions & 0 deletions mlir/test/Dialect/LLVMIR/indirectbr.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
// RUN: mlir-opt -split-input-file --verify-roundtrip %s | FileCheck %s

llvm.func @ib0(%dest : !llvm.ptr, %arg0 : i32, %arg1 : i32) -> i32 {
llvm.indirectbr %dest : !llvm.ptr, [
^head(%arg0 : i32),
^tail(%arg1, %arg0 : i32, i32)
]
^head(%r0 : i32):
llvm.return %r0 : i32
^tail(%r1 : i32, %r2 : i32):
%r = llvm.add %r1, %r2 : i32
llvm.return %r : i32
}

// CHECK: llvm.func @ib0(%[[Addr:.*]]: !llvm.ptr, %[[A0:.*]]: i32, %[[A1:.*]]: i32) -> i32 {
// CHECK: llvm.indirectbr %[[Addr]] : !llvm.ptr, [
// CHECK: ^bb1(%[[A0:.*]] : i32)
// CHECK: ^bb2(%[[A1:.*]], %[[A0]] : i32, i32)
// CHECK: ]
// CHECK: ^bb1(%[[Op0:.*]]: i32):
// CHECK: llvm.return %[[Op0]] : i32
// CHECK: ^bb2(%[[Op1:.*]]: i32, %[[Op2:.*]]: i32):
// CHECK: %[[Op3:.*]] = llvm.add %[[Op1]], %[[Op2]] : i32
// CHECK: llvm.return %[[Op3]] : i32
// CHECK: }

// -----

llvm.func @ib1(%dest : !llvm.ptr) {
llvm.indirectbr %dest : !llvm.ptr, []
}

// CHECK: llvm.func @ib1(%[[Addr:.*]]: !llvm.ptr) {
// CHECK: llvm.indirectbr %[[Addr]] : !llvm.ptr, []
// CHECK: }

// -----

// CHECK: llvm.func @test_indirectbr_phi(
// CHECK-SAME: %arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: !llvm.ptr, %arg3: i32) -> i32 {
llvm.func @callee(!llvm.ptr, i32, i32) -> i32
llvm.func @test_indirectbr_phi(%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: !llvm.ptr, %arg3: i32) -> i32 {
%0 = llvm.mlir.undef : i1
%1 = llvm.mlir.addressof @test_indirectbr_phi : !llvm.ptr
%2 = llvm.blockaddress <function = @test_indirectbr_phi, tag = <id = 1>> : !llvm.ptr
// CHECK: %[[ONE:.*]] = llvm.mlir.constant(1 : i32) : i32
%3 = llvm.mlir.constant(1 : i32) : i32
// CHECK: %[[TWO:.*]] = llvm.mlir.constant(2 : i32) : i32
%4 = llvm.mlir.constant(2 : i32) : i32
%5 = llvm.select %0, %2, %arg0 : i1, !llvm.ptr
// CHECK: llvm.indirectbr {{.*}} : !llvm.ptr, [
// CHECK: ^[[HEAD_BB:.*]],
// CHECK: ^[[TAIL_BB:.*]](%[[ONE]] : i32)
// CHECK: ]
llvm.indirectbr %5 : !llvm.ptr, [
^bb1,
^bb2(%3 : i32)
]
^bb1:
// CHECK: ^[[HEAD_BB]]:
// CHECK: llvm.indirectbr {{.*}} : !llvm.ptr, [
// CHECK: ^[[TAIL_BB]](%[[TWO]] : i32),
// CHECK: ^[[END_BB:.*]]
// CHECK: ]
%6 = llvm.select %0, %2, %arg0 : i1, !llvm.ptr
llvm.indirectbr %6 : !llvm.ptr, [
^bb2(%4 : i32),
^bb3
]
^bb2(%7: i32):
// CHECK: ^[[TAIL_BB]](%[[BLOCK_ARG:.*]]: i32):
// CHECK: {{.*}} = llvm.call @callee({{.*}}, %[[BLOCK_ARG]])
// CHECK: llvm.return
%8 = llvm.call @callee(%arg1, %arg3, %7) : (!llvm.ptr, i32, i32) -> i32
llvm.return %8 : i32
^bb3:
// CHECK: ^[[END_BB]]:
// CHECK: llvm.blocktag
// CHECK: llvm.return
// CHECK: }
llvm.blocktag <id = 1>
llvm.return %arg3 : i32
}
12 changes: 0 additions & 12 deletions mlir/test/Target/LLVMIR/Import/import-failure.ll
Original file line number Diff line number Diff line change
@@ -1,17 +1,5 @@
; RUN: not mlir-translate -import-llvm -emit-expensive-warnings -split-input-file %s 2>&1 -o /dev/null | FileCheck %s

; CHECK: <unknown>
; CHECK-SAME: error: unhandled instruction: indirectbr ptr %dst, [label %bb1, label %bb2]
define i32 @unhandled_instruction(ptr %dst) {
indirectbr ptr %dst, [label %bb1, label %bb2]
bb1:
ret i32 0
bb2:
ret i32 1
}

; // -----

; Check that debug intrinsics with an unsupported argument are dropped.

declare void @llvm.dbg.value(metadata, metadata, metadata)
Expand Down
Loading
Loading