-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][ODS] Switch declarative rewrite rules to properties structs #124876
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][ODS] Switch declarative rewrite rules to properties structs #124876
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-core Author: Krzysztof Drewniak (krzysz00) ChangesNow that we have collective builders that take This also enables us to have declarative rewrite rules that match non-attribute properties in a future PR. This commit also adds a basic test for the generated matchers since there didn't seem to already be one. Full diff: https://github.com/llvm/llvm-project/pull/124876.diff 2 Files Affected:
diff --git a/mlir/test/mlir-tblgen/rewriter-attributes-properties.td b/mlir/test/mlir-tblgen/rewriter-attributes-properties.td
new file mode 100644
index 00000000000000..77869d36cc12ee
--- /dev/null
+++ b/mlir/test/mlir-tblgen/rewriter-attributes-properties.td
@@ -0,0 +1,47 @@
+// RUN: mlir-tblgen -gen-rewriters -I %S/../../include %s | FileCheck %s
+
+include "mlir/IR/OpBase.td"
+include "mlir/IR/PatternBase.td"
+
+def Test_Dialect : Dialect {
+ let name = "test";
+}
+class NS_Op<string mnemonic, list<Trait> traits> :
+ Op<Test_Dialect, mnemonic, traits>;
+
+def AOp : NS_Op<"a_op", []> {
+ let arguments = (ins
+ I32:$x,
+ I32Attr:$y
+ );
+
+ let results = (outs I32:$z);
+}
+
+def BOp : NS_Op<"b_op", []> {
+ let arguments = (ins
+ I32Attr:$y
+ );
+
+ let results = (outs I32:$z);
+}
+
+def test1 : Pat<(AOp (BOp:$x $y), $_), (AOp $x, $y)>;
+// CHECK-LABEL: struct test1
+// CHECK: ::llvm::LogicalResult matchAndRewrite
+// CHECK: ::mlir::IntegerAttr y;
+// CHECK: test::BOp x;
+// CHECK: ::llvm::SmallVector<::mlir::Operation *, 4> tblgen_ops;
+// CHECK: tblgen_ops.push_back(op0);
+// CHECK: x = castedOp1;
+// CHECK: tblgen_attr = castedOp1.getProperties().getY();
+// CHECK: if (!(tblgen_attr))
+// CHECK: y = tblgen_attr;
+// CHECK: tblgen_ops.push_back(op1);
+
+// CHECK: test::AOp tblgen_AOp_0;
+// CHECK: ::llvm::SmallVector<::mlir::Value, 4> tblgen_values;
+// CHECK: test::AOp::Properties tblgen_props;
+// CHECK: tblgen_values.push_back((*x.getODSResults(0).begin()));
+// CHECK: tblgen_props.y = ::llvm::dyn_cast_if_present<decltype(tblgen_props.y)>(y);
+// CHECK: tblgen_AOp_0 = rewriter.create<test::AOp>(odsLoc, tblgen_types, tblgen_values, tblgen_props);
diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp
index a041c4d3277798..9d8d20798dc8db 100644
--- a/mlir/tools/mlir-tblgen/RewriterGen.cpp
+++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp
@@ -122,7 +122,7 @@ class PatternEmitter {
// Emits C++ statements for matching the `argIndex`-th argument of the given
// DAG `tree` as an attribute.
- void emitAttributeMatch(DagNode tree, StringRef opName, int argIndex,
+ void emitAttributeMatch(DagNode tree, StringRef castedName, int argIndex,
int depth);
// Emits C++ for checking a match with a corresponding match failure
@@ -664,7 +664,7 @@ void PatternEmitter::emitOpMatch(DagNode tree, StringRef opName, int depth) {
/*variadicSubIndex=*/std::nullopt);
++nextOperand;
} else if (isa<NamedAttribute *>(opArg)) {
- emitAttributeMatch(tree, opName, opArgIdx, depth);
+ emitAttributeMatch(tree, castedName, opArgIdx, depth);
} else {
PrintFatalError(loc, "unhandled case when matching op");
}
@@ -864,16 +864,22 @@ void PatternEmitter::emitVariadicOperandMatch(DagNode tree,
os.unindent() << "}\n";
}
-void PatternEmitter::emitAttributeMatch(DagNode tree, StringRef opName,
+void PatternEmitter::emitAttributeMatch(DagNode tree, StringRef castedName,
int argIndex, int depth) {
Operator &op = tree.getDialectOp(opMap);
auto *namedAttr = cast<NamedAttribute *>(op.getArg(argIndex));
const auto &attr = namedAttr->attr;
os << "{\n";
- os.indent() << formatv("auto tblgen_attr = {0}->getAttrOfType<{1}>(\"{2}\");"
- "(void)tblgen_attr;\n",
- opName, attr.getStorageType(), namedAttr->name);
+ if (op.getDialect().usePropertiesForAttributes()) {
+ os.indent() << formatv("auto tblgen_attr = {0}.getProperties().{1}();\n",
+ castedName, op.getGetterName(namedAttr->name));
+ } else {
+ os.indent() << formatv(
+ "auto tblgen_attr = {0}->getAttrOfType<{1}>(\"{2}\");"
+ "(void)tblgen_attr;\n",
+ castedName, attr.getStorageType(), namedAttr->name);
+ }
// TODO: This should use getter method to avoid duplication.
if (attr.hasDefaultValue()) {
@@ -887,7 +893,7 @@ void PatternEmitter::emitAttributeMatch(DagNode tree, StringRef opName,
// That is precisely what getDiscardableAttr() returns on missing
// attributes.
} else {
- emitMatchCheck(opName, tgfmt("tblgen_attr", &fmtCtx),
+ emitMatchCheck(castedName, tgfmt("tblgen_attr", &fmtCtx),
formatv("\"expected op '{0}' to have attribute '{1}' "
"of type '{2}'\"",
op.getOperationName(), namedAttr->name,
@@ -918,7 +924,7 @@ void PatternEmitter::emitAttributeMatch(DagNode tree, StringRef opName,
}
}
emitStaticVerifierCall(
- verifier, opName, "tblgen_attr",
+ verifier, castedName, "tblgen_attr",
formatv("\"op '{0}' attribute '{1}' failed to satisfy constraint: "
"'{2}'\"",
op.getOperationName(), namedAttr->name,
@@ -1532,6 +1538,7 @@ std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
LLVM_DEBUG(llvm::dbgs() << '\n');
Operator &resultOp = tree.getDialectOp(opMap);
+ bool useProperties = resultOp.getDialect().usePropertiesForAttributes();
auto numOpArgs = resultOp.getNumArgs();
auto numPatArgs = tree.getNumArgs();
@@ -1623,9 +1630,10 @@ std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
createAggregateLocalVarsForOpArgs(tree, childNodeNames, depth);
// Then create the op.
- os.scope("", "\n}\n").os << formatv(
- "{0} = rewriter.create<{1}>({2}, tblgen_values, tblgen_attrs);",
- valuePackName, resultOp.getQualCppClassName(), locToUse);
+ os.scope("", "\n}\n").os
+ << formatv("{0} = rewriter.create<{1}>({2}, tblgen_values, {3});",
+ valuePackName, resultOp.getQualCppClassName(), locToUse,
+ useProperties ? "tblgen_props" : "tblgen_attrs");
return resultValue;
}
@@ -1682,8 +1690,9 @@ std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
}
}
os << formatv("{0} = rewriter.create<{1}>({2}, tblgen_types, "
- "tblgen_values, tblgen_attrs);\n",
- valuePackName, resultOp.getQualCppClassName(), locToUse);
+ "tblgen_values, {3});\n",
+ valuePackName, resultOp.getQualCppClassName(), locToUse,
+ useProperties ? "tblgen_props" : "tblgen_attrs");
os.unindent() << "}\n";
return resultValue;
}
@@ -1791,12 +1800,21 @@ void PatternEmitter::createAggregateLocalVarsForOpArgs(
DagNode node, const ChildNodeIndexNameMap &childNodeNames, int depth) {
Operator &resultOp = node.getDialectOp(opMap);
+ bool useProperties = resultOp.getDialect().usePropertiesForAttributes();
auto scope = os.scope();
os << formatv("::llvm::SmallVector<::mlir::Value, 4> "
"tblgen_values; (void)tblgen_values;\n");
- os << formatv("::llvm::SmallVector<::mlir::NamedAttribute, 4> "
- "tblgen_attrs; (void)tblgen_attrs;\n");
+ if (useProperties) {
+ os << formatv("{0}::Properties tblgen_props; (void)tblgen_props;\n",
+ resultOp.getQualCppClassName());
+ } else {
+ os << formatv("::llvm::SmallVector<::mlir::NamedAttribute, 4> "
+ "tblgen_attrs; (void)tblgen_attrs;\n");
+ }
+ const char *setPropCmd =
+ "tblgen_props.{0} = "
+ "::llvm::dyn_cast_if_present<decltype(tblgen_props.{0})>({1});\n";
const char *addAttrCmd =
"if (auto tmpAttr = {1}) {\n"
" tblgen_attrs.emplace_back(rewriter.getStringAttr(\"{0}\"), "
@@ -1814,13 +1832,23 @@ void PatternEmitter::createAggregateLocalVarsForOpArgs(
if (!subTree.isNativeCodeCall())
PrintFatalError(loc, "only NativeCodeCall allowed in nested dag node "
"for creating attribute");
- os << formatv(addAttrCmd, opArgName, childNodeNames.lookup(argIndex));
+
+ if (useProperties) {
+ os << formatv(setPropCmd, opArgName, childNodeNames.lookup(argIndex));
+ } else {
+ os << formatv(addAttrCmd, opArgName, childNodeNames.lookup(argIndex));
+ }
} else {
auto leaf = node.getArgAsLeaf(argIndex);
// The argument in the result DAG pattern.
auto patArgName = node.getArgName(argIndex);
- os << formatv(addAttrCmd, opArgName,
- handleOpArgument(leaf, patArgName));
+ if (useProperties) {
+ os << formatv(setPropCmd, opArgName,
+ handleOpArgument(leaf, patArgName));
+ } else {
+ os << formatv(addAttrCmd, opArgName,
+ handleOpArgument(leaf, patArgName));
+ }
}
continue;
}
@@ -1876,11 +1904,18 @@ void PatternEmitter::createAggregateLocalVarsForOpArgs(
const auto *sameVariadicSize =
resultOp.getTrait("::mlir::OpTrait::SameVariadicOperandSize");
if (!sameVariadicSize) {
- const char *setSizes = R"(
- tblgen_attrs.emplace_back(rewriter.getStringAttr("operandSegmentSizes"),
- rewriter.getDenseI32ArrayAttr({{ {0} }));
- )";
- os.printReindented(formatv(setSizes, llvm::join(sizes, ", ")).str());
+ if (useProperties) {
+ const char *setSizes = R"(
+ tblgen_props.operandSegmentSizes = {{ {0} };
+ )";
+ os.printReindented(formatv(setSizes, llvm::join(sizes, ", ")).str());
+ } else {
+ const char *setSizes = R"(
+ tblgen_attrs.emplace_back(rewriter.getStringAttr("operandSegmentSizes"),
+ rewriter.getDenseI32ArrayAttr({{ {0} }));
+ )";
+ os.printReindented(formatv(setSizes, llvm::join(sizes, ", ")).str());
+ }
}
}
}
|
233c3eb
to
650bf05
Compare
650bf05
to
4e1dc86
Compare
a65fe2d
to
4e5cfed
Compare
4e1dc86
to
d059f73
Compare
Ping |
4e5cfed
to
aca6afa
Compare
d059f73
to
1c8f3eb
Compare
1c8f3eb
to
3abe842
Compare
Now that we have collective builders that take `const [RelevantOp]::Properties &` arguments, we don't need to serialize all the attributes that'll be set during an output pattern into a dictionary attribute. Similarly, we can use the properties struct to get the attributes instead of needing to go through the big if statement in getAttrOfType<>(). This also enables us to have declarative rewrite rules that match non-attribute properties in a future PR. This commit also adds a basic test for the generated matchers since there didn't seem to already be one.
3abe842
to
0ce3cec
Compare
Now that we have collective builders that take
const [RelevantOp]::Properties &
arguments, we don't need to serialize all the attributes that'll be set during an output pattern into a dictionary attribute. Similarly, we can use the properties struct to get the attributes instead of needing to go through the big if statement in getAttrOfType<>().This also enables us to have declarative rewrite rules that match non-attribute properties in a future PR.
This commit also adds a basic test for the generated matchers since there didn't seem to already be one.