Skip to content

Commit 2d99c81

Browse files
committed
[mlir-tblgen] Support either in Tablegen DRR.
Add a new directive `either` to specify the operands can be matched in either order Reviewed By: jpienaar, Mogball Differential Revision: https://reviews.llvm.org/D110666
1 parent 1b409df commit 2d99c81

File tree

7 files changed

+244
-50
lines changed

7 files changed

+244
-50
lines changed

mlir/docs/DeclarativeRewrites.md

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -774,6 +774,23 @@ Explicitly-specified return types will take precedence over return types
774774
inferred from op traits or user-defined builders. The return types of values
775775
replacing root op results cannot be overridden.
776776

777+
### `either`
778+
779+
The `either` directive is used to specify the operands may be matched in either
780+
order.
781+
782+
```tablegen
783+
def : Pat<(TwoArgOp (either $firstArg, (AnOp $secondArg))),
784+
(...)>;
785+
```
786+
787+
The above pattern will accept either `"test.TwoArgOp"(%I32Arg, %AnOpArg)` and
788+
`"test.TwoArgOp"(%AnOpArg, %I32Arg)`.
789+
790+
Only operand is supported with `either` and note that an operation with
791+
`Commutative` trait doesn't imply that it'll have the same behavior than
792+
`either` while pattern matching.
793+
777794
## Debugging Tips
778795

779796
### Run `mlir-tblgen` to see the generated content

mlir/include/mlir/IR/OpBase.td

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2730,6 +2730,21 @@ def location;
27302730

27312731
def returnType;
27322732

2733+
// Directive used to specify the operands may be matched in either order. When
2734+
// two adjacents are marked with `either`, it'll try to match the operands in
2735+
// either ordering of constraints. Example:
2736+
//
2737+
// ```
2738+
// (TwoArgOp (either $firstArg, (AnOp $secondArg)))
2739+
// ```
2740+
// The above pattern will accept either `"test.TwoArgOp"(%I32Arg, %AnOpArg)` and
2741+
// `"test.TwoArgOp"(%AnOpArg, %I32Arg)`.
2742+
//
2743+
// Only operand is supported with `either` and note that an operation with
2744+
// `Commutative` trait doesn't imply that it'll have the same behavior than
2745+
// `either` while pattern matching.
2746+
def either;
2747+
27332748
//===----------------------------------------------------------------------===//
27342749
// Attribute and Type generation
27352750
//===----------------------------------------------------------------------===//

mlir/include/mlir/TableGen/Pattern.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,9 @@ class DagNode {
186186
// Returns true if this DAG node is wrapping native code call.
187187
bool isNativeCodeCall() const;
188188

189+
// Returns whether this DAG is an `either` specifier.
190+
bool isEither() const;
191+
189192
// Returns true if this DAG node is an operation.
190193
bool isOperation() const;
191194

mlir/lib/TableGen/Pattern.cpp

Lines changed: 43 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ bool DagNode::isNativeCodeCall() const {
113113

114114
bool DagNode::isOperation() const {
115115
return !isNativeCodeCall() && !isReplaceWithValue() &&
116-
!isLocationDirective() && !isReturnTypeDirective();
116+
!isLocationDirective() && !isReturnTypeDirective() && !isEither();
117117
}
118118

119119
llvm::StringRef DagNode::getNativeCodeTemplate() const {
@@ -142,7 +142,9 @@ Operator &DagNode::getDialectOp(RecordOperatorMap *mapper) const {
142142
}
143143

144144
int DagNode::getNumOps() const {
145-
int count = isReplaceWithValue() ? 0 : 1;
145+
// We want to get number of operations recursively involved in the DAG tree.
146+
// All other directives should be excluded.
147+
int count = isOperation() ? 1 : 0;
146148
for (int i = 0, e = getNumArgs(); i != e; ++i) {
147149
if (auto child = getArgAsNestedDag(i))
148150
count += child.getNumOps();
@@ -184,6 +186,11 @@ bool DagNode::isReturnTypeDirective() const {
184186
return dagOpDef->getName() == "returnType";
185187
}
186188

189+
bool DagNode::isEither() const {
190+
auto *dagOpDef = cast<llvm::DefInit>(node->getOperator())->getDef();
191+
return dagOpDef->getName() == "either";
192+
}
193+
187194
void DagNode::print(raw_ostream &os) const {
188195
if (node)
189196
node->print(os);
@@ -764,22 +771,25 @@ void Pattern::collectBoundSymbols(DagNode tree, SymbolInfoMap &infoMap,
764771
if (tree.isOperation()) {
765772
auto &op = getDialectOp(tree);
766773
auto numOpArgs = op.getNumArgs();
774+
int numEither = 0;
767775

768-
// The pattern might have trailing directives.
776+
// We need to exclude the trailing directives and `either` directive groups
777+
// two operands of the operation.
769778
int numDirectives = 0;
770779
for (int i = numTreeArgs - 1; i >= 0; --i) {
771780
if (auto dagArg = tree.getArgAsNestedDag(i)) {
772781
if (dagArg.isLocationDirective() || dagArg.isReturnTypeDirective())
773782
++numDirectives;
774-
else
775-
break;
783+
else if (dagArg.isEither())
784+
++numEither;
776785
}
777786
}
778787

779-
if (numOpArgs != numTreeArgs - numDirectives) {
780-
auto err = formatv("op '{0}' argument number mismatch: "
781-
"{1} in pattern vs. {2} in definition",
782-
op.getOperationName(), numTreeArgs, numOpArgs);
788+
if (numOpArgs != numTreeArgs - numDirectives + numEither) {
789+
auto err =
790+
formatv("op '{0}' argument number mismatch: "
791+
"{1} in pattern vs. {2} in definition",
792+
op.getOperationName(), numTreeArgs + numEither, numOpArgs);
783793
PrintFatalError(&def, err);
784794
}
785795

@@ -791,10 +801,30 @@ void Pattern::collectBoundSymbols(DagNode tree, SymbolInfoMap &infoMap,
791801
verifyBind(infoMap.bindOpResult(treeName, op), treeName);
792802
}
793803

794-
for (int i = 0; i != numTreeArgs; ++i) {
804+
// The operand in `either` DAG should be bound to the operation in the
805+
// parent DagNode.
806+
auto collectSymbolInEither = [&](DagNode parent, DagNode tree,
807+
int &opArgIdx) {
808+
for (int i = 0; i < tree.getNumArgs(); ++i, ++opArgIdx) {
809+
if (DagNode subTree = tree.getArgAsNestedDag(i)) {
810+
collectBoundSymbols(subTree, infoMap, isSrcPattern);
811+
} else {
812+
auto argName = tree.getArgName(i);
813+
if (!argName.empty() && argName != "_")
814+
verifyBind(infoMap.bindOpArgument(parent, argName, op, opArgIdx),
815+
argName);
816+
}
817+
}
818+
};
819+
820+
for (int i = 0, opArgIdx = 0; i != numTreeArgs; ++i, ++opArgIdx) {
795821
if (auto treeArg = tree.getArgAsNestedDag(i)) {
796-
// This DAG node argument is a DAG node itself. Go inside recursively.
797-
collectBoundSymbols(treeArg, infoMap, isSrcPattern);
822+
if (treeArg.isEither()) {
823+
collectSymbolInEither(tree, treeArg, opArgIdx);
824+
} else {
825+
// This DAG node argument is a DAG node itself. Go inside recursively.
826+
collectBoundSymbols(treeArg, infoMap, isSrcPattern);
827+
}
798828
continue;
799829
}
800830

@@ -806,7 +836,7 @@ void Pattern::collectBoundSymbols(DagNode tree, SymbolInfoMap &infoMap,
806836
if (!treeArgName.empty() && treeArgName != "_") {
807837
LLVM_DEBUG(llvm::dbgs() << "found symbol bound to op argument: "
808838
<< treeArgName << '\n');
809-
verifyBind(infoMap.bindOpArgument(tree, treeArgName, op, i),
839+
verifyBind(infoMap.bindOpArgument(tree, treeArgName, op, opArgIdx),
810840
treeArgName);
811841
}
812842
}

mlir/test/lib/Dialect/Test/TestOps.td

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1328,6 +1328,30 @@ def : Pat<(OneI32ResultOp),
13281328
(replaceWithValue $results__2),
13291329
ConstantAttr<I32Attr, "2">)>;
13301330

1331+
//===----------------------------------------------------------------------===//
1332+
// Test Patterns (either)
1333+
1334+
def TestEitherOpA : TEST_Op<"either_op_a"> {
1335+
let arguments = (ins AnyInteger:$arg0, AnyInteger:$arg1, AnyInteger:$arg2);
1336+
let results = (outs I32:$output);
1337+
}
1338+
1339+
def TestEitherOpB : TEST_Op<"either_op_b"> {
1340+
let arguments = (ins AnyInteger:$arg0);
1341+
let results = (outs I32:$output);
1342+
}
1343+
1344+
def : Pat<(TestEitherOpA (either I32:$arg1, I16:$arg2), $_),
1345+
(TestEitherOpB $arg2)>;
1346+
1347+
def : Pat<(TestEitherOpA (either (TestEitherOpB I32:$arg1), I16:$arg2), $_),
1348+
(TestEitherOpB $arg2)>;
1349+
1350+
def : Pat<(TestEitherOpA (either (TestEitherOpB I32:$arg1),
1351+
(TestEitherOpB I16:$arg2)),
1352+
$_),
1353+
(TestEitherOpB $arg2)>;
1354+
13311355
//===----------------------------------------------------------------------===//
13321356
// Test Patterns (Location)
13331357

mlir/test/mlir-tblgen/pattern.mlir

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -531,6 +531,40 @@ func @redundantTest(%arg0: i32) -> i32 {
531531
return %0 : i32
532532
}
533533

534+
//===----------------------------------------------------------------------===//
535+
// Test either directive
536+
//===----------------------------------------------------------------------===//
537+
538+
// CHECK: @either_dag_leaf_only
539+
func @either_dag_leaf_only_1(%arg0 : i32, %arg1 : i16, %arg2 : i8) -> () {
540+
// CHECK: "test.either_op_b"(%arg1) : (i16) -> i32
541+
%0 = "test.either_op_a"(%arg0, %arg1, %arg2) : (i32, i16, i8) -> i32
542+
// CHECK: "test.either_op_b"(%arg1) : (i16) -> i32
543+
%1 = "test.either_op_a"(%arg1, %arg0, %arg2) : (i16, i32, i8) -> i32
544+
return
545+
}
546+
547+
// CHECK: @either_dag_leaf_dag_node
548+
func @either_dag_leaf_dag_node(%arg0 : i32, %arg1 : i16, %arg2 : i8) -> () {
549+
%0 = "test.either_op_b"(%arg0) : (i32) -> i32
550+
// CHECK: "test.either_op_b"(%arg1) : (i16) -> i32
551+
%1 = "test.either_op_a"(%0, %arg1, %arg2) : (i32, i16, i8) -> i32
552+
// CHECK: "test.either_op_b"(%arg1) : (i16) -> i32
553+
%2 = "test.either_op_a"(%arg1, %0, %arg2) : (i16, i32, i8) -> i32
554+
return
555+
}
556+
557+
// CHECK: @either_dag_node_dag_node
558+
func @either_dag_node_dag_node(%arg0 : i32, %arg1 : i16, %arg2 : i8) -> () {
559+
%0 = "test.either_op_b"(%arg0) : (i32) -> i32
560+
%1 = "test.either_op_b"(%arg1) : (i16) -> i32
561+
// CHECK: "test.either_op_b"(%arg1) : (i16) -> i32
562+
%2 = "test.either_op_a"(%0, %1, %arg2) : (i32, i32, i8) -> i32
563+
// CHECK: "test.either_op_b"(%arg1) : (i16) -> i32
564+
%3 = "test.either_op_a"(%1, %0, %arg2) : (i32, i32, i8) -> i32
565+
return
566+
}
567+
534568
//===----------------------------------------------------------------------===//
535569
// Test that ops without type deduction can be created with type builders.
536570
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)