Skip to content

Commit f3e5594

Browse files
authored
[mlir][ODS] Switch declarative rewrite rules to properties structs (#124876)
Now that we have collective builders that take `const [RelevantOp]::Properties &` arguments, we don't need to serialize all the attributes that'll be set during an output pattern into a dictionary attribute. Similarly, we can use the properties struct to get the attributes instead of needing to go through the big if statement in getAttrOfType<>(). This also enables us to have declarative rewrite rules that match non-attribute properties in a future PR. This commit also adds a basic test for the generated matchers since there didn't seem to already be one.
1 parent b334321 commit f3e5594

File tree

2 files changed

+97
-23
lines changed

2 files changed

+97
-23
lines changed
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
// RUN: mlir-tblgen -gen-rewriters -I %S/../../include %s | FileCheck %s
2+
3+
include "mlir/IR/OpBase.td"
4+
include "mlir/IR/PatternBase.td"
5+
6+
def Test_Dialect : Dialect {
7+
let name = "test";
8+
}
9+
class NS_Op<string mnemonic, list<Trait> traits> :
10+
Op<Test_Dialect, mnemonic, traits>;
11+
12+
def AOp : NS_Op<"a_op", []> {
13+
let arguments = (ins
14+
I32:$x,
15+
I32Attr:$y
16+
);
17+
18+
let results = (outs I32:$z);
19+
}
20+
21+
def BOp : NS_Op<"b_op", []> {
22+
let arguments = (ins
23+
I32Attr:$y
24+
);
25+
26+
let results = (outs I32:$z);
27+
}
28+
29+
def test1 : Pat<(AOp (BOp:$x $y), $_), (AOp $x, $y)>;
30+
// CHECK-LABEL: struct test1
31+
// CHECK: ::llvm::LogicalResult matchAndRewrite
32+
// CHECK-DAG: ::mlir::IntegerAttr y;
33+
// CHECK-DAG: test::BOp x;
34+
// CHECK-DAG: ::llvm::SmallVector<::mlir::Operation *, 4> tblgen_ops;
35+
// CHECK: tblgen_ops.push_back(op0);
36+
// CHECK: x = castedOp1;
37+
// CHECK: tblgen_attr = castedOp1.getProperties().getY();
38+
// CHECK: if (!(tblgen_attr))
39+
// CHECK: y = tblgen_attr;
40+
// CHECK: tblgen_ops.push_back(op1);
41+
42+
// CHECK: test::AOp tblgen_AOp_0;
43+
// CHECK: ::llvm::SmallVector<::mlir::Value, 4> tblgen_values;
44+
// CHECK: test::AOp::Properties tblgen_props;
45+
// CHECK: tblgen_values.push_back((*x.getODSResults(0).begin()));
46+
// CHECK: tblgen_props.y = ::llvm::dyn_cast_if_present<decltype(tblgen_props.y)>(y);
47+
// CHECK: tblgen_AOp_0 = rewriter.create<test::AOp>(odsLoc, tblgen_types, tblgen_values, tblgen_props);

mlir/tools/mlir-tblgen/RewriterGen.cpp

Lines changed: 50 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ class PatternEmitter {
122122

123123
// Emits C++ statements for matching the `argIndex`-th argument of the given
124124
// DAG `tree` as an attribute.
125-
void emitAttributeMatch(DagNode tree, StringRef opName, int argIndex,
125+
void emitAttributeMatch(DagNode tree, StringRef castedName, int argIndex,
126126
int depth);
127127

128128
// Emits C++ for checking a match with a corresponding match failure
@@ -664,7 +664,7 @@ void PatternEmitter::emitOpMatch(DagNode tree, StringRef opName, int depth) {
664664
/*variadicSubIndex=*/std::nullopt);
665665
++nextOperand;
666666
} else if (isa<NamedAttribute *>(opArg)) {
667-
emitAttributeMatch(tree, opName, opArgIdx, depth);
667+
emitAttributeMatch(tree, castedName, opArgIdx, depth);
668668
} else {
669669
PrintFatalError(loc, "unhandled case when matching op");
670670
}
@@ -864,16 +864,22 @@ void PatternEmitter::emitVariadicOperandMatch(DagNode tree,
864864
os.unindent() << "}\n";
865865
}
866866

867-
void PatternEmitter::emitAttributeMatch(DagNode tree, StringRef opName,
867+
void PatternEmitter::emitAttributeMatch(DagNode tree, StringRef castedName,
868868
int argIndex, int depth) {
869869
Operator &op = tree.getDialectOp(opMap);
870870
auto *namedAttr = cast<NamedAttribute *>(op.getArg(argIndex));
871871
const auto &attr = namedAttr->attr;
872872

873873
os << "{\n";
874-
os.indent() << formatv("auto tblgen_attr = {0}->getAttrOfType<{1}>(\"{2}\");"
875-
"(void)tblgen_attr;\n",
876-
opName, attr.getStorageType(), namedAttr->name);
874+
if (op.getDialect().usePropertiesForAttributes()) {
875+
os.indent() << formatv("auto tblgen_attr = {0}.getProperties().{1}();\n",
876+
castedName, op.getGetterName(namedAttr->name));
877+
} else {
878+
os.indent() << formatv(
879+
"auto tblgen_attr = {0}->getAttrOfType<{1}>(\"{2}\");"
880+
"(void)tblgen_attr;\n",
881+
castedName, attr.getStorageType(), namedAttr->name);
882+
}
877883

878884
// TODO: This should use getter method to avoid duplication.
879885
if (attr.hasDefaultValue()) {
@@ -887,7 +893,7 @@ void PatternEmitter::emitAttributeMatch(DagNode tree, StringRef opName,
887893
// That is precisely what getDiscardableAttr() returns on missing
888894
// attributes.
889895
} else {
890-
emitMatchCheck(opName, tgfmt("tblgen_attr", &fmtCtx),
896+
emitMatchCheck(castedName, tgfmt("tblgen_attr", &fmtCtx),
891897
formatv("\"expected op '{0}' to have attribute '{1}' "
892898
"of type '{2}'\"",
893899
op.getOperationName(), namedAttr->name,
@@ -918,7 +924,7 @@ void PatternEmitter::emitAttributeMatch(DagNode tree, StringRef opName,
918924
}
919925
}
920926
emitStaticVerifierCall(
921-
verifier, opName, "tblgen_attr",
927+
verifier, castedName, "tblgen_attr",
922928
formatv("\"op '{0}' attribute '{1}' failed to satisfy constraint: "
923929
"'{2}'\"",
924930
op.getOperationName(), namedAttr->name,
@@ -1532,6 +1538,7 @@ std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
15321538
LLVM_DEBUG(llvm::dbgs() << '\n');
15331539

15341540
Operator &resultOp = tree.getDialectOp(opMap);
1541+
bool useProperties = resultOp.getDialect().usePropertiesForAttributes();
15351542
auto numOpArgs = resultOp.getNumArgs();
15361543
auto numPatArgs = tree.getNumArgs();
15371544

@@ -1623,9 +1630,10 @@ std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
16231630
createAggregateLocalVarsForOpArgs(tree, childNodeNames, depth);
16241631

16251632
// Then create the op.
1626-
os.scope("", "\n}\n").os << formatv(
1627-
"{0} = rewriter.create<{1}>({2}, tblgen_values, tblgen_attrs);",
1628-
valuePackName, resultOp.getQualCppClassName(), locToUse);
1633+
os.scope("", "\n}\n").os
1634+
<< formatv("{0} = rewriter.create<{1}>({2}, tblgen_values, {3});",
1635+
valuePackName, resultOp.getQualCppClassName(), locToUse,
1636+
useProperties ? "tblgen_props" : "tblgen_attrs");
16291637
return resultValue;
16301638
}
16311639

@@ -1682,8 +1690,9 @@ std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
16821690
}
16831691
}
16841692
os << formatv("{0} = rewriter.create<{1}>({2}, tblgen_types, "
1685-
"tblgen_values, tblgen_attrs);\n",
1686-
valuePackName, resultOp.getQualCppClassName(), locToUse);
1693+
"tblgen_values, {3});\n",
1694+
valuePackName, resultOp.getQualCppClassName(), locToUse,
1695+
useProperties ? "tblgen_props" : "tblgen_attrs");
16871696
os.unindent() << "}\n";
16881697
return resultValue;
16891698
}
@@ -1791,16 +1800,27 @@ void PatternEmitter::createAggregateLocalVarsForOpArgs(
17911800
DagNode node, const ChildNodeIndexNameMap &childNodeNames, int depth) {
17921801
Operator &resultOp = node.getDialectOp(opMap);
17931802

1803+
bool useProperties = resultOp.getDialect().usePropertiesForAttributes();
17941804
auto scope = os.scope();
17951805
os << formatv("::llvm::SmallVector<::mlir::Value, 4> "
17961806
"tblgen_values; (void)tblgen_values;\n");
1797-
os << formatv("::llvm::SmallVector<::mlir::NamedAttribute, 4> "
1798-
"tblgen_attrs; (void)tblgen_attrs;\n");
1807+
if (useProperties) {
1808+
os << formatv("{0}::Properties tblgen_props; (void)tblgen_props;\n",
1809+
resultOp.getQualCppClassName());
1810+
} else {
1811+
os << formatv("::llvm::SmallVector<::mlir::NamedAttribute, 4> "
1812+
"tblgen_attrs; (void)tblgen_attrs;\n");
1813+
}
17991814

1815+
const char *setPropCmd =
1816+
"tblgen_props.{0} = "
1817+
"::llvm::dyn_cast_if_present<decltype(tblgen_props.{0})>({1});\n";
18001818
const char *addAttrCmd =
18011819
"if (auto tmpAttr = {1}) {\n"
18021820
" tblgen_attrs.emplace_back(rewriter.getStringAttr(\"{0}\"), "
18031821
"tmpAttr);\n}\n";
1822+
const char *setterCmd = (useProperties) ? setPropCmd : addAttrCmd;
1823+
18041824
int numVariadic = 0;
18051825
bool hasOperandSegmentSizes = false;
18061826
std::vector<std::string> sizes;
@@ -1814,13 +1834,13 @@ void PatternEmitter::createAggregateLocalVarsForOpArgs(
18141834
if (!subTree.isNativeCodeCall())
18151835
PrintFatalError(loc, "only NativeCodeCall allowed in nested dag node "
18161836
"for creating attribute");
1817-
os << formatv(addAttrCmd, opArgName, childNodeNames.lookup(argIndex));
1837+
1838+
os << formatv(setterCmd, opArgName, childNodeNames.lookup(argIndex));
18181839
} else {
18191840
auto leaf = node.getArgAsLeaf(argIndex);
18201841
// The argument in the result DAG pattern.
18211842
auto patArgName = node.getArgName(argIndex);
1822-
os << formatv(addAttrCmd, opArgName,
1823-
handleOpArgument(leaf, patArgName));
1843+
os << formatv(setterCmd, opArgName, handleOpArgument(leaf, patArgName));
18241844
}
18251845
continue;
18261846
}
@@ -1876,11 +1896,18 @@ void PatternEmitter::createAggregateLocalVarsForOpArgs(
18761896
const auto *sameVariadicSize =
18771897
resultOp.getTrait("::mlir::OpTrait::SameVariadicOperandSize");
18781898
if (!sameVariadicSize) {
1879-
const char *setSizes = R"(
1880-
tblgen_attrs.emplace_back(rewriter.getStringAttr("operandSegmentSizes"),
1881-
rewriter.getDenseI32ArrayAttr({{ {0} }));
1882-
)";
1883-
os.printReindented(formatv(setSizes, llvm::join(sizes, ", ")).str());
1899+
if (useProperties) {
1900+
const char *setSizes = R"(
1901+
tblgen_props.operandSegmentSizes = {{ {0} };
1902+
)";
1903+
os.printReindented(formatv(setSizes, llvm::join(sizes, ", ")).str());
1904+
} else {
1905+
const char *setSizes = R"(
1906+
tblgen_attrs.emplace_back(rewriter.getStringAttr("operandSegmentSizes"),
1907+
rewriter.getDenseI32ArrayAttr({{ {0} }));
1908+
)";
1909+
os.printReindented(formatv(setSizes, llvm::join(sizes, ", ")).str());
1910+
}
18841911
}
18851912
}
18861913
}

0 commit comments

Comments
 (0)