Skip to content

Commit 616c86a

Browse files
committed
[mlir][drr] Set operand segment in rewrite
This allows some basic variadic operands in rewrites. There were some workarounds employed (like "aliasing" the attribute). Couldn't find a way to do this directly with properties.
1 parent 3cac608 commit 616c86a

File tree

3 files changed

+57
-0
lines changed

3 files changed

+57
-0
lines changed

mlir/test/lib/Dialect/Test/TestOps.td

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2495,6 +2495,29 @@ def TestDefaultStrAttrHasValueOp : TEST_Op<"has_str_value"> {
24952495
def : Pat<(TestDefaultStrAttrNoValueOp $value),
24962496
(TestDefaultStrAttrHasValueOp ConstantStrAttr<StrAttr, "foo">)>;
24972497

2498+
//===----------------------------------------------------------------------===//
2499+
// Test Ops with variadics
2500+
//===----------------------------------------------------------------------===//
2501+
2502+
def TestVariadicRewriteSrcOp : TEST_Op<"variadic_rewrite_src_op", [AttrSizedOperandSegments]> {
2503+
let arguments = (ins
2504+
Variadic<AnyType>:$arg,
2505+
AnyType:$brg,
2506+
Variadic<AnyType>:$crg
2507+
);
2508+
}
2509+
2510+
def TestVariadicRewriteDstOp : TEST_Op<"variadic_rewrite_dst_op", [AttrSizedOperandSegments]> {
2511+
let arguments = (ins
2512+
AnyType:$brg,
2513+
Variadic<AnyType>:$crg,
2514+
Variadic<AnyType>:$arg
2515+
);
2516+
}
2517+
2518+
def : Pat<(TestVariadicRewriteSrcOp $arg, $brg, $crg),
2519+
(TestVariadicRewriteDstOp $brg, $crg, $arg)>;
2520+
24982521
//===----------------------------------------------------------------------===//
24992522
// Test Ops with Default-Valued Attributes and Differing Print Settings
25002523
//===----------------------------------------------------------------------===//

mlir/test/mlir-tblgen/pattern.mlir

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -683,3 +683,16 @@ func.func @testConstantStrAttr() -> () {
683683
test.no_str_value {value = "bar"}
684684
return
685685
}
686+
687+
//===----------------------------------------------------------------------===//
688+
// Test that patterns with variadics propagate sizes
689+
//===----------------------------------------------------------------------===//
690+
691+
func.func @testVariadic(%arg_0: i32, %arg_1: i32, %brg: i64,
692+
%crg_0: f32, %crg_1: f32, %crg_2: f32, %crg_3: f32) -> () {
693+
// CHECK: "test.variadic_rewrite_dst_op"(%arg2, %arg3, %arg4, %arg5, %arg6, %arg0, %arg1) <{operandSegmentSizes = array<i32: 1, 4, 2>}> : (i64, f32, f32, f32, f32, i32, i32) -> ()
694+
"test.variadic_rewrite_src_op"(%arg_0, %arg_1, %brg,
695+
%crg_0, %crg_1, %crg_2, %crg_3) {operandSegmentSizes = array<i32: 2, 1, 4>} :
696+
(i32, i32, i64, f32, f32, f32, f32) -> ()
697+
return
698+
}

mlir/tools/mlir-tblgen/RewriterGen.cpp

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1743,10 +1743,15 @@ void PatternEmitter::createAggregateLocalVarsForOpArgs(
17431743
"if (auto tmpAttr = {1}) {\n"
17441744
" tblgen_attrs.emplace_back(rewriter.getStringAttr(\"{0}\"), "
17451745
"tmpAttr);\n}\n";
1746+
int numVariadic = 0;
1747+
bool hasOperandSegmentSizes = false;
1748+
std::vector<std::string> sizes;
17461749
for (int argIndex = 0, e = resultOp.getNumArgs(); argIndex < e; ++argIndex) {
17471750
if (resultOp.getArg(argIndex).is<NamedAttribute *>()) {
17481751
// The argument in the op definition.
17491752
auto opArgName = resultOp.getArgName(argIndex);
1753+
hasOperandSegmentSizes =
1754+
hasOperandSegmentSizes || opArgName == "operandSegmentSizes";
17501755
if (auto subTree = node.getArgAsNestedDag(argIndex)) {
17511756
if (!subTree.isNativeCodeCall())
17521757
PrintFatalError(loc, "only NativeCodeCall allowed in nested dag node "
@@ -1766,6 +1771,7 @@ void PatternEmitter::createAggregateLocalVarsForOpArgs(
17661771
resultOp.getArg(argIndex).get<NamedTypeConstraint *>();
17671772
std::string varName;
17681773
if (operand->isVariadic()) {
1774+
++numVariadic;
17691775
std::string range;
17701776
if (node.isNestedDagArg(argIndex)) {
17711777
range = childNodeNames.lookup(argIndex);
@@ -1777,7 +1783,9 @@ void PatternEmitter::createAggregateLocalVarsForOpArgs(
17771783
range = symbolInfoMap.getValueAndRangeUse(range);
17781784
os << formatv("for (auto v: {0}) {{\n tblgen_values.push_back(v);\n}\n",
17791785
range);
1786+
sizes.push_back(formatv("static_cast<int32_t>({0}.size())", range));
17801787
} else {
1788+
sizes.push_back("1");
17811789
os << formatv("tblgen_values.push_back(");
17821790
if (node.isNestedDagArg(argIndex)) {
17831791
os << symbolInfoMap.getValueAndRangeUse(
@@ -1804,6 +1812,19 @@ void PatternEmitter::createAggregateLocalVarsForOpArgs(
18041812
os << ");\n";
18051813
}
18061814
}
1815+
1816+
if (numVariadic > 1 && !hasOperandSegmentSizes) {
1817+
// Only set size if it can't be computed.
1818+
const auto *sameVariadicSize =
1819+
resultOp.getTrait("::mlir::OpTrait::SameVariadicOperandSize");
1820+
if (!sameVariadicSize) {
1821+
const char *setSizes = R"(
1822+
tblgen_attrs.emplace_back(rewriter.getStringAttr("operandSegmentSizes"),
1823+
rewriter.getDenseI32ArrayAttr({{ {0} }));
1824+
)";
1825+
os.printReindented(formatv(setSizes, llvm::join(sizes, ", ")).str());
1826+
}
1827+
}
18071828
}
18081829

18091830
StaticMatcherHelper::StaticMatcherHelper(raw_ostream &os,

0 commit comments

Comments
 (0)