Skip to content

Commit 471a612

Browse files
authored
[mlir][drr] Add warning for simple case of mismatched variadic. (#84040)
When a variadic argument is expected but not provided the compilation fails later with a difficult to follow compilation error. Add a simple check to catch one such case. This is not yet general as it doesn't yet check leaf nodes.
1 parent 52d5b8e commit 471a612

File tree

2 files changed

+43
-0
lines changed

2 files changed

+43
-0
lines changed

mlir/test/mlir-tblgen/rewriter-errors.td

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
// RUN: not mlir-tblgen -gen-rewriters -I %S/../../include -DERROR5 %s 2>&1 | FileCheck --check-prefix=ERROR5 %s
66
// RUN: not mlir-tblgen -gen-rewriters -I %S/../../include -DERROR6 %s 2>&1 | FileCheck --check-prefix=ERROR6 %s
77
// RUN: not mlir-tblgen -gen-rewriters -I %S/../../include -DERROR7 %s 2>&1 | FileCheck --check-prefix=ERROR7 %s
8+
// RUN: not mlir-tblgen -gen-rewriters -I %S/../../include -DERROR8 %s 2>&1 | FileCheck --check-prefix=ERROR8 %s
89

910
include "mlir/IR/OpBase.td"
1011
include "mlir/IR/PatternBase.td"
@@ -64,3 +65,17 @@ def : Pat<(OpB:$result $val, $attr), (OpA $val, $val), [(AnyInteger:$result)]>;
6465
// ERROR7: [[@LINE+1]]:1: error: type constraint requires exactly one argument
6566
def : Pat<(OpB:$opB $val, $attr), (OpA $val, $val), [(AnyInteger $opB, $val)]>;
6667
#endif
68+
69+
def OpC : A_Op<"op_c">, Results<(outs AnyInteger)>;
70+
def OpD : A_Op<"op_d">, Arguments<(ins Variadic<AnyInteger>:$vargs)>, Results<(outs AnyInteger)>;
71+
72+
#ifdef ERROR8
73+
// Check that op with variadic operand gets variadic operand in target,
74+
//
75+
// FIXME: this should be an error.
76+
def : Pat<(OpB:$opB $val, $attr), (OpD $val)>;
77+
78+
// ERROR8: [[@LINE+2]]
79+
// ERROR8-SAME: op expects variadic operand `vargs`, while provided is non-variadic
80+
def : Pat<(OpB:$opB $val, $attr), (OpD (OpC))>;
81+
#endif

mlir/tools/mlir-tblgen/RewriterGen.cpp

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,15 @@
1111
//===----------------------------------------------------------------------===//
1212

1313
#include "mlir/Support/IndentedOstream.h"
14+
#include "mlir/TableGen/Argument.h"
1415
#include "mlir/TableGen/Attribute.h"
1516
#include "mlir/TableGen/CodeGenHelpers.h"
1617
#include "mlir/TableGen/Format.h"
1718
#include "mlir/TableGen/GenInfo.h"
1819
#include "mlir/TableGen/Operator.h"
1920
#include "mlir/TableGen/Pattern.h"
2021
#include "mlir/TableGen/Predicate.h"
22+
#include "mlir/TableGen/Property.h"
2123
#include "mlir/TableGen/Type.h"
2224
#include "llvm/ADT/FunctionExtras.h"
2325
#include "llvm/ADT/SetVector.h"
@@ -1518,10 +1520,36 @@ std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
15181520
// the key. This includes both bound and unbound child nodes.
15191521
ChildNodeIndexNameMap childNodeNames;
15201522

1523+
// If the argument is a type constraint, then its an operand. Check if the
1524+
// op's argument is variadic that the argument in the pattern is too.
1525+
auto checkIfMatchedVariadic = [&](int i) {
1526+
// FIXME: This does not yet check for variable/leaf case.
1527+
// FIXME: Change so that native code call can be handled.
1528+
const auto *operand =
1529+
llvm::dyn_cast_if_present<NamedTypeConstraint *>(resultOp.getArg(i));
1530+
if (!operand || !operand->isVariadic())
1531+
return;
1532+
1533+
auto child = tree.getArgAsNestedDag(i);
1534+
if (!child)
1535+
return;
1536+
1537+
// Skip over replaceWithValues.
1538+
while (child.isReplaceWithValue()) {
1539+
if (!(child = child.getArgAsNestedDag(0)))
1540+
return;
1541+
}
1542+
if (!child.isNativeCodeCall() && !child.isVariadic())
1543+
PrintFatalError(loc, formatv("op expects variadic operand `{0}`, while "
1544+
"provided is non-variadic",
1545+
resultOp.getArgName(i)));
1546+
};
1547+
15211548
// First go through all the child nodes who are nested DAG constructs to
15221549
// create ops for them and remember the symbol names for them, so that we can
15231550
// use the results in the current node. This happens in a recursive manner.
15241551
for (int i = 0, e = tree.getNumArgs() - tail.numDirectives; i != e; ++i) {
1552+
checkIfMatchedVariadic(i);
15251553
if (auto child = tree.getArgAsNestedDag(i))
15261554
childNodeNames[i] = handleResultPattern(child, i, depth + 1);
15271555
}

0 commit comments

Comments
 (0)