Skip to content

Commit b1de971

Browse files
committed
[mlir][ODS] Add support for specifying the successors of an operation.
This revision add support in ODS for specifying the successors of an operation. Successors are specified via the `successors` list: ``` let successors = (successor AnySuccessor:$target, AnySuccessor:$otherTarget); ``` Differential Revision: https://reviews.llvm.org/D74783
1 parent 93813e5 commit b1de971

File tree

18 files changed

+361
-65
lines changed

18 files changed

+361
-65
lines changed

mlir/docs/OpDefinitions.md

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -279,6 +279,24 @@ Similar to variadic operands, `Variadic<...>` can also be used for results.
279279
And similarly, `SameVariadicResultSize` for multiple variadic results in the
280280
same operation.
281281

282+
### Operation successors
283+
284+
For terminator operations, the successors are specified inside of the
285+
`dag`-typed `successors`, led by `successor`:
286+
287+
```tablegen
288+
let successors = (successor
289+
<successor-constraint>:$<successor-name>,
290+
...
291+
);
292+
```
293+
294+
#### Variadic successors
295+
296+
Similar to the `Variadic` class used for variadic operands and results,
297+
`VariadicSuccessor<...>` can be used for successors. Variadic successors can
298+
currently only be specified as the last successor in the successor list.
299+
282300
### Operation traits and constraints
283301

284302
Traits are operation properties that affect syntax or semantics. MLIR C++

mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,25 @@ def LLVMInt : TypeConstraint<
3131
CPred<"$_self.cast<::mlir::LLVM::LLVMType>().isIntegerTy()">]>,
3232
"LLVM dialect integer">;
3333

34+
def LLVMIntBase : TypeConstraint<
35+
And<[LLVM_Type.predicate,
36+
CPred<"$_self.cast<::mlir::LLVM::LLVMType>().isIntegerTy()">]>,
37+
"LLVM dialect integer">;
38+
39+
// Integer type of a specific width.
40+
class LLVMI<int width>
41+
: Type<And<[
42+
LLVM_Type.predicate,
43+
CPred<
44+
"$_self.cast<::mlir::LLVM::LLVMType>().isIntegerTy(" # width # ")">]>,
45+
"LLVM dialect " # width # "-bit integer">,
46+
BuildableType<
47+
"::mlir::LLVM::LLVMType::getIntNTy("
48+
"$_builder.getContext()->getRegisteredDialect<LLVM::LLVMDialect>(),"
49+
# width # ")">;
50+
51+
def LLVMI1 : LLVMI<1>;
52+
3453
// Base class for LLVM operations. Defines the interface to the llvm::IRBuilder
3554
// used to translate to LLVM IR proper.
3655
class LLVM_OpBase<Dialect dialect, string mnemonic, list<OpTrait> traits = []> :

mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td

Lines changed: 10 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -72,8 +72,7 @@ class LLVM_ZeroResultOp<string mnemonic, list<OpTrait> traits = []> :
7272
// Base class for LLVM terminator operations. All terminator operations have
7373
// zero results and an optional list of successors.
7474
class LLVM_TerminatorOp<string mnemonic, list<OpTrait> traits = []> :
75-
LLVM_Op<mnemonic, !listconcat(traits, [Terminator])>,
76-
Arguments<(ins Variadic<LLVM_Type>:$args)>, Results<(outs)> {
75+
LLVM_Op<mnemonic, !listconcat(traits, [Terminator])> {
7776
let builders = [
7877
OpBuilder<
7978
"Builder *, OperationState &result, "
@@ -320,15 +319,10 @@ def LLVM_InvokeOp : LLVM_Op<"invoke", [Terminator]>,
320319
Arguments<(ins OptionalAttr<FlatSymbolRefAttr>:$callee,
321320
Variadic<LLVM_Type>)>,
322321
Results<(outs Variadic<LLVM_Type>)> {
322+
let successors = (successor AnySuccessor:$normalDest,
323+
AnySuccessor:$unwindDest);
324+
323325
let builders = [OpBuilder<
324-
"Builder *b, OperationState &result, ArrayRef<Type> tys, "
325-
"FlatSymbolRefAttr callee, ValueRange ops, Block* normal, "
326-
"ValueRange normalOps, Block* unwind, ValueRange unwindOps",
327-
[{
328-
result.addAttribute("callee", callee);
329-
build(b, result, tys, ops, normal, normalOps, unwind, unwindOps);
330-
}]>,
331-
OpBuilder<
332326
"Builder *b, OperationState &result, ArrayRef<Type> tys, "
333327
"ValueRange ops, Block* normal, "
334328
"ValueRange normalOps, Block* unwind, ValueRange unwindOps",
@@ -460,19 +454,19 @@ def LLVM_SelectOp
460454

461455
// Terminators.
462456
def LLVM_BrOp : LLVM_TerminatorOp<"br", []> {
457+
let successors = (successor AnySuccessor:$dest);
463458
let parser = [{ return parseBrOp(parser, result); }];
464459
let printer = [{ printBrOp(p, *this); }];
465460
}
466461
def LLVM_CondBrOp : LLVM_TerminatorOp<"cond_br", []> {
467-
let verifier = [{
468-
if (getNumSuccessors() != 2)
469-
return emitOpError("expected exactly two successors");
470-
return success();
471-
}];
462+
let arguments = (ins LLVMI1:$condition);
463+
let successors = (successor AnySuccessor:$trueDest, AnySuccessor:$falseDest);
464+
472465
let parser = [{ return parseCondBrOp(parser, result); }];
473466
let printer = [{ printCondBrOp(p, *this); }];
474467
}
475-
def LLVM_ReturnOp : LLVM_TerminatorOp<"return", []> {
468+
def LLVM_ReturnOp : LLVM_TerminatorOp<"return", []>,
469+
Arguments<(ins Variadic<LLVM_Type>:$args)> {
476470
string llvmBuilder = [{
477471
if ($_numOperands != 0)
478472
builder.CreateRet($args[0]);

mlir/include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -41,12 +41,14 @@ def SPV_BranchOp : SPV_Op<"Branch", [InFunctionScope, Terminator]> {
4141
```
4242
}];
4343

44-
let arguments = (ins
45-
Variadic<AnyType>:$block_arguments
46-
);
44+
let arguments = (ins);
4745

4846
let results = (outs);
4947

48+
let successors = (successor AnySuccessor:$target);
49+
50+
let verifier = [{ return success(); }];
51+
5052
let builders = [
5153
OpBuilder<
5254
"Builder *, OperationState &state, "
@@ -60,12 +62,10 @@ def SPV_BranchOp : SPV_Op<"Branch", [InFunctionScope, Terminator]> {
6062

6163
let extraClassDeclaration = [{
6264
/// Returns the branch target block.
63-
Block *getTarget() { return getOperation()->getSuccessor(0); }
65+
Block *getTarget() { return target(); }
6466

6567
/// Returns the block arguments.
66-
operand_range getBlockArguments() {
67-
return getOperation()->getSuccessorOperands(0);
68-
}
68+
operand_range getBlockArguments() { return targetOperands(); }
6969
}];
7070

7171
let autogenSerialization = 0;
@@ -115,12 +115,14 @@ def SPV_BranchConditionalOp : SPV_Op<"BranchConditional",
115115

116116
let arguments = (ins
117117
SPV_Bool:$condition,
118-
Variadic<AnyType>:$branch_arguments,
119118
OptionalAttr<I32ArrayAttr>:$branch_weights
120119
);
121120

122121
let results = (outs);
123122

123+
let successors = (successor AnySuccessor:$trueTarget,
124+
AnySuccessor:$falseTarget);
125+
124126
let builders = [
125127
OpBuilder<
126128
"Builder *builder, OperationState &state, Value condition, "

mlir/include/mlir/Dialect/StandardOps/IR/Ops.td

Lines changed: 5 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -232,12 +232,10 @@ def BranchOp : Std_Op<"br", [Terminator]> {
232232
^bb3(%3: tensor<*xf32>):
233233
}];
234234

235-
let arguments = (ins Variadic<AnyType>:$operands);
235+
let successors = (successor AnySuccessor:$dest);
236236

237-
let builders = [OpBuilder<
238-
"Builder *, OperationState &result, Block *dest,"
239-
"ValueRange operands = {}", [{
240-
result.addSuccessor(dest, operands);
237+
let builders = [OpBuilder<"Builder *, OperationState &result, Block *dest", [{
238+
result.addSuccessor(dest, llvm::None);
241239
}]>];
242240

243241
// BranchOp is fully verified by traits.
@@ -513,16 +511,8 @@ def CondBranchOp : Std_Op<"cond_br", [Terminator]> {
513511
...
514512
}];
515513

516-
let arguments = (ins I1:$condition, Variadic<AnyType>:$branchOperands);
517-
518-
let builders = [OpBuilder<
519-
"Builder *, OperationState &result, Value condition,"
520-
"Block *trueDest, ValueRange trueOperands,"
521-
"Block *falseDest, ValueRange falseOperands", [{
522-
result.addOperands(condition);
523-
result.addSuccessor(trueDest, trueOperands);
524-
result.addSuccessor(falseDest, falseOperands);
525-
}]>];
514+
let arguments = (ins I1:$condition);
515+
let successors = (successor AnySuccessor:$trueDest, AnySuccessor:$falseDest);
526516

527517
// CondBranchOp is fully verified by traits.
528518
let verifier = ?;

mlir/include/mlir/IR/OpBase.td

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,10 @@ class AttrConstraint<Pred predicate, string description = ""> :
185185
class RegionConstraint<Pred predicate, string description = ""> :
186186
Constraint<predicate, description>;
187187

188+
// Subclass for constraints on a successor.
189+
class SuccessorConstraint<Pred predicate, string description = ""> :
190+
Constraint<predicate, description>;
191+
188192
// How to use these constraint categories:
189193
//
190194
// * Use TypeConstraint to specify
@@ -1341,6 +1345,21 @@ class SizedRegion<int numBlocks> : Region<
13411345
CPred<"$_self.getBlocks().size() == " # numBlocks>,
13421346
"region with " # numBlocks # " blocks">;
13431347

1348+
//===----------------------------------------------------------------------===//
1349+
// Successor definitions
1350+
//===----------------------------------------------------------------------===//
1351+
1352+
class Successor<Pred condition, string descr = ""> :
1353+
SuccessorConstraint<condition, descr>;
1354+
1355+
// Any successor.
1356+
def AnySuccessor : Successor<?, "any successor">;
1357+
1358+
// A variadic successor constraint. It expands to zero or more of the base
1359+
// successor.
1360+
class VariadicSuccessor<Successor successor>
1361+
: Successor<successor.predicate, successor.description>;
1362+
13441363
//===----------------------------------------------------------------------===//
13451364
// OpTrait definitions
13461365
//===----------------------------------------------------------------------===//
@@ -1537,6 +1556,9 @@ def outs;
15371556
// Marker used to identify the region list for an op.
15381557
def region;
15391558

1559+
// Marker used to identify the successor list for an op.
1560+
def successor;
1561+
15401562
// Class for defining a custom builder.
15411563
//
15421564
// TableGen generates several generic builders for each op by default (see
@@ -1587,6 +1609,9 @@ class Op<Dialect dialect, string mnemonic, list<OpTrait> props = []> {
15871609
// The list of regions of the op. Default to 0 regions.
15881610
dag regions = (region);
15891611

1612+
// The list of successors of the op. Default to 0 successors.
1613+
dag successors = (successor);
1614+
15901615
// Attribute getters can be added to the op by adding an Attr member
15911616
// with the name and type of the attribute. E.g., adding int attribute
15921617
// with name "value" and type "i32":

mlir/include/mlir/TableGen/Constraint.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ class Constraint {
4848
StringRef getDescription() const;
4949

5050
// Constraint kind
51-
enum Kind { CK_Attr, CK_Region, CK_Type, CK_Uncategorized };
51+
enum Kind { CK_Attr, CK_Region, CK_Successor, CK_Type, CK_Uncategorized };
5252

5353
Kind getKind() const { return kind; }
5454

mlir/include/mlir/TableGen/Operator.h

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
#include "mlir/TableGen/Dialect.h"
2020
#include "mlir/TableGen/OpTrait.h"
2121
#include "mlir/TableGen/Region.h"
22+
#include "mlir/TableGen/Successor.h"
2223
#include "mlir/TableGen/Type.h"
2324
#include "llvm/ADT/PointerUnion.h"
2425
#include "llvm/ADT/SmallVector.h"
@@ -138,6 +139,20 @@ class Operator {
138139
// Returns the `index`-th region.
139140
const NamedRegion &getRegion(unsigned index) const;
140141

142+
// Successors.
143+
using const_successor_iterator = const NamedSuccessor *;
144+
const_successor_iterator successor_begin() const;
145+
const_successor_iterator successor_end() const;
146+
llvm::iterator_range<const_successor_iterator> getSuccessors() const;
147+
148+
// Returns the number of successors.
149+
unsigned getNumSuccessors() const;
150+
// Returns the `index`-th successor.
151+
const NamedSuccessor &getSuccessor(unsigned index) const;
152+
153+
// Returns the number of variadic successors in this operation.
154+
unsigned getNumVariadicSuccessors() const;
155+
141156
// Trait.
142157
using const_trait_iterator = const OpTrait *;
143158
const_trait_iterator trait_begin() const;
@@ -193,6 +208,9 @@ class Operator {
193208
// The results of the op.
194209
SmallVector<NamedTypeConstraint, 4> results;
195210

211+
// The successors of this op.
212+
SmallVector<NamedSuccessor, 0> successors;
213+
196214
// The traits of the op.
197215
SmallVector<OpTrait, 4> traits;
198216

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
//===- Successor.h - TableGen successor definitions -------------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_TABLEGEN_SUCCESSOR_H_
10+
#define MLIR_TABLEGEN_SUCCESSOR_H_
11+
12+
#include "mlir/Support/LLVM.h"
13+
#include "mlir/TableGen/Constraint.h"
14+
15+
namespace mlir {
16+
namespace tblgen {
17+
18+
// Wrapper class providing helper methods for accessing Successor defined in
19+
// TableGen.
20+
class Successor : public Constraint {
21+
public:
22+
using Constraint::Constraint;
23+
24+
static bool classof(const Constraint *c) {
25+
return c->getKind() == CK_Successor;
26+
}
27+
28+
// Returns true if this successor is variadic.
29+
bool isVariadic() const;
30+
};
31+
32+
// A struct bundling a successor's constraint and its name.
33+
struct NamedSuccessor {
34+
// Returns true if this successor is variadic.
35+
bool isVariadic() const { return constraint.isVariadic(); }
36+
37+
StringRef name;
38+
Successor constraint;
39+
};
40+
41+
} // end namespace tblgen
42+
} // end namespace mlir
43+
44+
#endif // MLIR_TABLEGEN_SUCCESSOR_H_

mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -234,15 +234,14 @@ static ParseResult parseStoreOp(OpAsmParser &parser, OperationState &result) {
234234
static LogicalResult verify(InvokeOp op) {
235235
if (op.getNumResults() > 1)
236236
return op.emitOpError("must have 0 or 1 result");
237-
if (op.getNumSuccessors() != 2)
238-
return op.emitOpError("must have normal and unwind destinations");
239237

240-
if (op.getSuccessor(1)->empty())
238+
Block *unwindDest = op.unwindDest();
239+
if (unwindDest->empty())
241240
return op.emitError(
242241
"must have at least one operation in unwind destination");
243242

244243
// In unwind destination, first operation must be LandingpadOp
245-
if (!isa<LandingpadOp>(op.getSuccessor(1)->front()))
244+
if (!isa<LandingpadOp>(unwindDest->front()))
246245
return op.emitError("first operation in unwind destination should be a "
247246
"llvm.landingpad operation");
248247

mlir/lib/Dialect/SPIRV/SPIRVOps.cpp

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1036,14 +1036,6 @@ static void print(spirv::BranchOp branchOp, OpAsmPrinter &printer) {
10361036
printer.printSuccessorAndUseList(branchOp.getOperation(), /*index=*/0);
10371037
}
10381038

1039-
static LogicalResult verify(spirv::BranchOp branchOp) {
1040-
auto *op = branchOp.getOperation();
1041-
if (op->getNumSuccessors() != 1)
1042-
branchOp.emitOpError("must have exactly one successor");
1043-
1044-
return success();
1045-
}
1046-
10471039
//===----------------------------------------------------------------------===//
10481040
// spv.BranchConditionalOp
10491041
//===----------------------------------------------------------------------===//
@@ -1114,10 +1106,6 @@ static void print(spirv::BranchConditionalOp branchOp, OpAsmPrinter &printer) {
11141106
}
11151107

11161108
static LogicalResult verify(spirv::BranchConditionalOp branchOp) {
1117-
auto *op = branchOp.getOperation();
1118-
if (op->getNumSuccessors() != 2)
1119-
return branchOp.emitOpError("must have exactly two successors");
1120-
11211109
if (auto weights = branchOp.branch_weights()) {
11221110
if (weights->getValue().size() != 2) {
11231111
return branchOp.emitOpError("must have exactly two branch weights");

mlir/lib/TableGen/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ add_llvm_library(LLVMMLIRTableGen
1010
OpTrait.cpp
1111
Pattern.cpp
1212
Predicate.cpp
13+
Successor.cpp
1314
Type.cpp
1415

1516
ADDITIONAL_HEADER_DIRS

mlir/lib/TableGen/Constraint.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@ Constraint::Constraint(const llvm::Record *record)
2323
kind = CK_Attr;
2424
} else if (record->isSubClassOf("RegionConstraint")) {
2525
kind = CK_Region;
26+
} else if (record->isSubClassOf("SuccessorConstraint")) {
27+
kind = CK_Successor;
2628
} else {
2729
assert(record->isSubClassOf("Constraint"));
2830
}

0 commit comments

Comments
 (0)