Skip to content

[mlir][ODS] Switch declarative rewrite rules to properties structs #124876

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Mar 11, 2025

Conversation

krzysz00
Copy link
Contributor

Now that we have collective builders that take
const [RelevantOp]::Properties & arguments, we don't need to serialize all the attributes that'll be set during an output pattern into a dictionary attribute. Similarly, we can use the properties struct to get the attributes instead of needing to go through the big if statement in getAttrOfType<>().

This also enables us to have declarative rewrite rules that match non-attribute properties in a future PR.

This commit also adds a basic test for the generated matchers since there didn't seem to already be one.

@llvmbot
Copy link
Member

llvmbot commented Jan 29, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-core

Author: Krzysztof Drewniak (krzysz00)

Changes

Now that we have collective builders that take
const [RelevantOp]::Properties &amp; arguments, we don't need to serialize all the attributes that'll be set during an output pattern into a dictionary attribute. Similarly, we can use the properties struct to get the attributes instead of needing to go through the big if statement in getAttrOfType<>().

This also enables us to have declarative rewrite rules that match non-attribute properties in a future PR.

This commit also adds a basic test for the generated matchers since there didn't seem to already be one.


Full diff: https://github.com/llvm/llvm-project/pull/124876.diff

2 Files Affected:

  • (added) mlir/test/mlir-tblgen/rewriter-attributes-properties.td (+47)
  • (modified) mlir/tools/mlir-tblgen/RewriterGen.cpp (+58-23)
diff --git a/mlir/test/mlir-tblgen/rewriter-attributes-properties.td b/mlir/test/mlir-tblgen/rewriter-attributes-properties.td
new file mode 100644
index 00000000000000..77869d36cc12ee
--- /dev/null
+++ b/mlir/test/mlir-tblgen/rewriter-attributes-properties.td
@@ -0,0 +1,47 @@
+// RUN: mlir-tblgen -gen-rewriters -I %S/../../include %s | FileCheck %s
+
+include "mlir/IR/OpBase.td"
+include "mlir/IR/PatternBase.td"
+
+def Test_Dialect : Dialect {
+  let name = "test";
+}
+class NS_Op<string mnemonic, list<Trait> traits> :
+    Op<Test_Dialect, mnemonic, traits>;
+
+def AOp : NS_Op<"a_op", []> {
+  let arguments = (ins
+    I32:$x,
+    I32Attr:$y
+  );
+
+  let results = (outs I32:$z);
+}
+
+def BOp : NS_Op<"b_op", []> {
+  let arguments = (ins
+    I32Attr:$y
+  );
+
+  let results = (outs I32:$z);
+}
+
+def test1 : Pat<(AOp (BOp:$x $y), $_), (AOp $x, $y)>;
+// CHECK-LABEL: struct test1
+// CHECK: ::llvm::LogicalResult matchAndRewrite
+// CHECK: ::mlir::IntegerAttr y;
+// CHECK: test::BOp x;
+// CHECK: ::llvm::SmallVector<::mlir::Operation *, 4> tblgen_ops;
+// CHECK: tblgen_ops.push_back(op0);
+// CHECK: x = castedOp1;
+// CHECK: tblgen_attr = castedOp1.getProperties().getY();
+// CHECK: if (!(tblgen_attr))
+// CHECK: y = tblgen_attr;
+// CHECK: tblgen_ops.push_back(op1);
+
+// CHECK: test::AOp tblgen_AOp_0;
+// CHECK: ::llvm::SmallVector<::mlir::Value, 4> tblgen_values;
+// CHECK: test::AOp::Properties tblgen_props;
+// CHECK: tblgen_values.push_back((*x.getODSResults(0).begin()));
+// CHECK: tblgen_props.y = ::llvm::dyn_cast_if_present<decltype(tblgen_props.y)>(y);
+// CHECK: tblgen_AOp_0 = rewriter.create<test::AOp>(odsLoc, tblgen_types, tblgen_values, tblgen_props);
diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp
index a041c4d3277798..9d8d20798dc8db 100644
--- a/mlir/tools/mlir-tblgen/RewriterGen.cpp
+++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp
@@ -122,7 +122,7 @@ class PatternEmitter {
 
   // Emits C++ statements for matching the `argIndex`-th argument of the given
   // DAG `tree` as an attribute.
-  void emitAttributeMatch(DagNode tree, StringRef opName, int argIndex,
+  void emitAttributeMatch(DagNode tree, StringRef castedName, int argIndex,
                           int depth);
 
   // Emits C++ for checking a match with a corresponding match failure
@@ -664,7 +664,7 @@ void PatternEmitter::emitOpMatch(DagNode tree, StringRef opName, int depth) {
                        /*variadicSubIndex=*/std::nullopt);
       ++nextOperand;
     } else if (isa<NamedAttribute *>(opArg)) {
-      emitAttributeMatch(tree, opName, opArgIdx, depth);
+      emitAttributeMatch(tree, castedName, opArgIdx, depth);
     } else {
       PrintFatalError(loc, "unhandled case when matching op");
     }
@@ -864,16 +864,22 @@ void PatternEmitter::emitVariadicOperandMatch(DagNode tree,
   os.unindent() << "}\n";
 }
 
-void PatternEmitter::emitAttributeMatch(DagNode tree, StringRef opName,
+void PatternEmitter::emitAttributeMatch(DagNode tree, StringRef castedName,
                                         int argIndex, int depth) {
   Operator &op = tree.getDialectOp(opMap);
   auto *namedAttr = cast<NamedAttribute *>(op.getArg(argIndex));
   const auto &attr = namedAttr->attr;
 
   os << "{\n";
-  os.indent() << formatv("auto tblgen_attr = {0}->getAttrOfType<{1}>(\"{2}\");"
-                         "(void)tblgen_attr;\n",
-                         opName, attr.getStorageType(), namedAttr->name);
+  if (op.getDialect().usePropertiesForAttributes()) {
+    os.indent() << formatv("auto tblgen_attr = {0}.getProperties().{1}();\n",
+                           castedName, op.getGetterName(namedAttr->name));
+  } else {
+    os.indent() << formatv(
+        "auto tblgen_attr = {0}->getAttrOfType<{1}>(\"{2}\");"
+        "(void)tblgen_attr;\n",
+        castedName, attr.getStorageType(), namedAttr->name);
+  }
 
   // TODO: This should use getter method to avoid duplication.
   if (attr.hasDefaultValue()) {
@@ -887,7 +893,7 @@ void PatternEmitter::emitAttributeMatch(DagNode tree, StringRef opName,
     // That is precisely what getDiscardableAttr() returns on missing
     // attributes.
   } else {
-    emitMatchCheck(opName, tgfmt("tblgen_attr", &fmtCtx),
+    emitMatchCheck(castedName, tgfmt("tblgen_attr", &fmtCtx),
                    formatv("\"expected op '{0}' to have attribute '{1}' "
                            "of type '{2}'\"",
                            op.getOperationName(), namedAttr->name,
@@ -918,7 +924,7 @@ void PatternEmitter::emitAttributeMatch(DagNode tree, StringRef opName,
       }
     }
     emitStaticVerifierCall(
-        verifier, opName, "tblgen_attr",
+        verifier, castedName, "tblgen_attr",
         formatv("\"op '{0}' attribute '{1}' failed to satisfy constraint: "
                 "'{2}'\"",
                 op.getOperationName(), namedAttr->name,
@@ -1532,6 +1538,7 @@ std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
   LLVM_DEBUG(llvm::dbgs() << '\n');
 
   Operator &resultOp = tree.getDialectOp(opMap);
+  bool useProperties = resultOp.getDialect().usePropertiesForAttributes();
   auto numOpArgs = resultOp.getNumArgs();
   auto numPatArgs = tree.getNumArgs();
 
@@ -1623,9 +1630,10 @@ std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
     createAggregateLocalVarsForOpArgs(tree, childNodeNames, depth);
 
     // Then create the op.
-    os.scope("", "\n}\n").os << formatv(
-        "{0} = rewriter.create<{1}>({2}, tblgen_values, tblgen_attrs);",
-        valuePackName, resultOp.getQualCppClassName(), locToUse);
+    os.scope("", "\n}\n").os
+        << formatv("{0} = rewriter.create<{1}>({2}, tblgen_values, {3});",
+                   valuePackName, resultOp.getQualCppClassName(), locToUse,
+                   useProperties ? "tblgen_props" : "tblgen_attrs");
     return resultValue;
   }
 
@@ -1682,8 +1690,9 @@ std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
     }
   }
   os << formatv("{0} = rewriter.create<{1}>({2}, tblgen_types, "
-                "tblgen_values, tblgen_attrs);\n",
-                valuePackName, resultOp.getQualCppClassName(), locToUse);
+                "tblgen_values, {3});\n",
+                valuePackName, resultOp.getQualCppClassName(), locToUse,
+                useProperties ? "tblgen_props" : "tblgen_attrs");
   os.unindent() << "}\n";
   return resultValue;
 }
@@ -1791,12 +1800,21 @@ void PatternEmitter::createAggregateLocalVarsForOpArgs(
     DagNode node, const ChildNodeIndexNameMap &childNodeNames, int depth) {
   Operator &resultOp = node.getDialectOp(opMap);
 
+  bool useProperties = resultOp.getDialect().usePropertiesForAttributes();
   auto scope = os.scope();
   os << formatv("::llvm::SmallVector<::mlir::Value, 4> "
                 "tblgen_values; (void)tblgen_values;\n");
-  os << formatv("::llvm::SmallVector<::mlir::NamedAttribute, 4> "
-                "tblgen_attrs; (void)tblgen_attrs;\n");
+  if (useProperties) {
+    os << formatv("{0}::Properties tblgen_props; (void)tblgen_props;\n",
+                  resultOp.getQualCppClassName());
+  } else {
+    os << formatv("::llvm::SmallVector<::mlir::NamedAttribute, 4> "
+                  "tblgen_attrs; (void)tblgen_attrs;\n");
+  }
 
+  const char *setPropCmd =
+      "tblgen_props.{0} = "
+      "::llvm::dyn_cast_if_present<decltype(tblgen_props.{0})>({1});\n";
   const char *addAttrCmd =
       "if (auto tmpAttr = {1}) {\n"
       "  tblgen_attrs.emplace_back(rewriter.getStringAttr(\"{0}\"), "
@@ -1814,13 +1832,23 @@ void PatternEmitter::createAggregateLocalVarsForOpArgs(
         if (!subTree.isNativeCodeCall())
           PrintFatalError(loc, "only NativeCodeCall allowed in nested dag node "
                                "for creating attribute");
-        os << formatv(addAttrCmd, opArgName, childNodeNames.lookup(argIndex));
+
+        if (useProperties) {
+          os << formatv(setPropCmd, opArgName, childNodeNames.lookup(argIndex));
+        } else {
+          os << formatv(addAttrCmd, opArgName, childNodeNames.lookup(argIndex));
+        }
       } else {
         auto leaf = node.getArgAsLeaf(argIndex);
         // The argument in the result DAG pattern.
         auto patArgName = node.getArgName(argIndex);
-        os << formatv(addAttrCmd, opArgName,
-                      handleOpArgument(leaf, patArgName));
+        if (useProperties) {
+          os << formatv(setPropCmd, opArgName,
+                        handleOpArgument(leaf, patArgName));
+        } else {
+          os << formatv(addAttrCmd, opArgName,
+                        handleOpArgument(leaf, patArgName));
+        }
       }
       continue;
     }
@@ -1876,11 +1904,18 @@ void PatternEmitter::createAggregateLocalVarsForOpArgs(
     const auto *sameVariadicSize =
         resultOp.getTrait("::mlir::OpTrait::SameVariadicOperandSize");
     if (!sameVariadicSize) {
-      const char *setSizes = R"(
-        tblgen_attrs.emplace_back(rewriter.getStringAttr("operandSegmentSizes"),
-          rewriter.getDenseI32ArrayAttr({{ {0} }));
-          )";
-      os.printReindented(formatv(setSizes, llvm::join(sizes, ", ")).str());
+      if (useProperties) {
+        const char *setSizes = R"(
+          tblgen_props.operandSegmentSizes = {{ {0} };
+        )";
+        os.printReindented(formatv(setSizes, llvm::join(sizes, ", ")).str());
+      } else {
+        const char *setSizes = R"(
+          tblgen_attrs.emplace_back(rewriter.getStringAttr("operandSegmentSizes"),
+            rewriter.getDenseI32ArrayAttr({{ {0} }));
+            )";
+        os.printReindented(formatv(setSizes, llvm::join(sizes, ", ")).str());
+      }
     }
   }
 }

@krzysz00 krzysz00 force-pushed the users/krzysz00/declaritve-rewrite-with-properties branch from 233c3eb to 650bf05 Compare January 29, 2025 06:33
@krzysz00 krzysz00 force-pushed the users/krzysz00/declaritve-rewrite-with-properties branch from 650bf05 to 4e1dc86 Compare February 6, 2025 06:31
@krzysz00 krzysz00 force-pushed the users/krzysz00/collective-prop-builder branch from a65fe2d to 4e5cfed Compare February 12, 2025 07:00
@krzysz00 krzysz00 force-pushed the users/krzysz00/declaritve-rewrite-with-properties branch from 4e1dc86 to d059f73 Compare February 12, 2025 07:00
@krzysz00
Copy link
Contributor Author

krzysz00 commented Mar 2, 2025

Ping

@akuegel akuegel removed their request for review March 3, 2025 08:27
@krzysz00 krzysz00 force-pushed the users/krzysz00/collective-prop-builder branch from 4e5cfed to aca6afa Compare March 5, 2025 04:34
@krzysz00 krzysz00 force-pushed the users/krzysz00/declaritve-rewrite-with-properties branch from d059f73 to 1c8f3eb Compare March 5, 2025 04:34
Base automatically changed from users/krzysz00/collective-prop-builder to main March 5, 2025 19:17
@krzysz00 krzysz00 force-pushed the users/krzysz00/declaritve-rewrite-with-properties branch from 1c8f3eb to 3abe842 Compare March 8, 2025 23:44
Now that we have collective builders that take
`const [RelevantOp]::Properties &` arguments, we don't need to serialize
all the attributes that'll be set during an output pattern into a dictionary
attribute. Similarly, we can use the properties struct to get the attributes
instead of needing to go through the big if statement in getAttrOfType<>().

This also enables us to have declarative rewrite rules that match non-attribute
properties in a future PR.

This commit also adds a basic test for the generated matchers since there
didn't seem to already be one.
@krzysz00 krzysz00 force-pushed the users/krzysz00/declaritve-rewrite-with-properties branch from 3abe842 to 0ce3cec Compare March 11, 2025 02:57
@krzysz00 krzysz00 merged commit f3e5594 into main Mar 11, 2025
11 checks passed
@krzysz00 krzysz00 deleted the users/krzysz00/declaritve-rewrite-with-properties branch March 11, 2025 15:21
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:core MLIR Core Infrastructure mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants