Skip to content

[mlir] use irdl as matcher description in transform #89779

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
May 2, 2024
Merged

Conversation

ftynse
Copy link
Member

@ftynse ftynse commented Apr 23, 2024

Introduce a new Transform dialect extension that uses IRDL op definitions as matcher descriptors. IRDL allows one to essentially define additional op constraits to be verified and, unlike PDL, does not assume rewriting will happen. Leverage IRDL verification capability to filter out ops that match an IRDL definition without actually registering the corresponding operation with the system.

Introduce a new Transform dialect extension that uses IRDL op
definitions as matcher descriptors. IRDL allows one to essentially
define additional op constraits to be verified and, unlike PDL, does not
assume rewriting will happen. Leverage IRDL verification capability to
filter out ops that match an IRDL definition without actually
registering the corresponding operation with the system.
@ftynse ftynse requested a review from martin-luecke April 23, 2024 15:23
@math-fehr
Copy link
Contributor

I'm not that familiar with the transform side of things, but that looks really nice! (And the IRDL changes make much sense)

Copy link
Contributor

@martin-luecke martin-luecke left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM!

Comment on lines +76 to +79
if (!dialect.getOps<irdl::TypeOp>().empty() ||
!dialect.getOps<irdl::AttributeOp>().empty()) {
return emitOpError() << "IRDL types and attributes are not yet supported";
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are irdl::TypeOp and irdl::AttributeOp not be handled automatically when calling irdl::createVerifier on the irdl::Operation or do we need to explicitly process them before the operation?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need to explicitly process them

// Set the verifier for types.
WalkResult res = op.walk([&](TypeOp typeOp) {
DynamicAttrDefinition::VerifierFn verifier = getAttrOrTypeVerifier(
typeOp, dialects[typeOp.getParentOp()], types, attrs);
if (!verifier)
return WalkResult::interrupt();
types[typeOp]->setVerifyFn(std::move(verifier));
return WalkResult::advance();
});
if (res.wasInterrupted())
return failure();
// Set the verifier for attributes.
res = op.walk([&](AttributeOp attrOp) {
DynamicAttrDefinition::VerifierFn verifier = getAttrOrTypeVerifier(
attrOp, dialects[attrOp.getParentOp()], types, attrs);
if (!verifier)
return WalkResult::interrupt();
attrs[attrOp]->setVerifyFn(std::move(verifier));
return WalkResult::advance();
});
if (res.wasInterrupted())
return failure();
. I would like to keep this out of the initial PR.

@ftynse ftynse marked this pull request as ready for review May 2, 2024 13:02
@ftynse ftynse requested a review from nicolasvasilache as a code owner May 2, 2024 13:02
@llvmbot
Copy link
Member

llvmbot commented May 2, 2024

@llvm/pr-subscribers-mlir-irdl

@llvm/pr-subscribers-mlir

Author: Oleksandr "Alex" Zinenko (ftynse)

Changes

Introduce a new Transform dialect extension that uses IRDL op definitions as matcher descriptors. IRDL allows one to essentially define additional op constraits to be verified and, unlike PDL, does not assume rewriting will happen. Leverage IRDL verification capability to filter out ops that match an IRDL definition without actually registering the corresponding operation with the system.


Patch is 20.37 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/89779.diff

14 Files Affected:

  • (modified) mlir/include/mlir/Dialect/IRDL/IRDLVerifiers.h (+11)
  • (modified) mlir/include/mlir/Dialect/Transform/CMakeLists.txt (+1)
  • (added) mlir/include/mlir/Dialect/Transform/IRDLExtension/CMakeLists.txt (+6)
  • (added) mlir/include/mlir/Dialect/Transform/IRDLExtension/IRDLExtension.h (+21)
  • (added) mlir/include/mlir/Dialect/Transform/IRDLExtension/IRDLExtensionOps.h (+20)
  • (added) mlir/include/mlir/Dialect/Transform/IRDLExtension/IRDLExtensionOps.td (+36)
  • (modified) mlir/include/mlir/Dialect/Transform/PDLExtension/PDLExtension.h (+5)
  • (modified) mlir/include/mlir/InitAllExtensions.h (+2)
  • (modified) mlir/lib/Dialect/IRDL/IRDLLoading.cpp (+44-27)
  • (modified) mlir/lib/Dialect/Transform/CMakeLists.txt (+1)
  • (added) mlir/lib/Dialect/Transform/IRDLExtension/CMakeLists.txt (+12)
  • (added) mlir/lib/Dialect/Transform/IRDLExtension/IRDLExtension.cpp (+34)
  • (added) mlir/lib/Dialect/Transform/IRDLExtension/IRDLExtensionOps.cpp (+84)
  • (added) mlir/test/Dialect/Transform/irdl.mlir (+25)
diff --git a/mlir/include/mlir/Dialect/IRDL/IRDLVerifiers.h b/mlir/include/mlir/Dialect/IRDL/IRDLVerifiers.h
index 9ecb7c0107d7f8..89e99a63a5f104 100644
--- a/mlir/include/mlir/Dialect/IRDL/IRDLVerifiers.h
+++ b/mlir/include/mlir/Dialect/IRDL/IRDLVerifiers.h
@@ -30,7 +30,10 @@ class DynamicTypeDefinition;
 namespace mlir {
 namespace irdl {
 
+class AttributeOp;
 class Constraint;
+class OperationOp;
+class TypeOp;
 
 /// Provides context to the verification of constraints.
 /// It contains the assignment of variables to attributes, and the assignment
@@ -246,6 +249,14 @@ struct RegionConstraint {
   std::optional<SmallVector<unsigned>> argumentConstraints;
   std::optional<size_t> blockCount;
 };
+
+/// Generate an op verifier function from the given IRDL operation definition.
+llvm::unique_function<LogicalResult(Operation *) const> createVerifier(
+    OperationOp operation,
+    const DenseMap<irdl::TypeOp, std::unique_ptr<DynamicTypeDefinition>>
+        &typeDefs,
+    const DenseMap<irdl::AttributeOp, std::unique_ptr<DynamicAttrDefinition>>
+        &attrDefs);
 } // namespace irdl
 } // namespace mlir
 
diff --git a/mlir/include/mlir/Dialect/Transform/CMakeLists.txt b/mlir/include/mlir/Dialect/Transform/CMakeLists.txt
index 0cd71ec6919d9e..b6155b5f573f1b 100644
--- a/mlir/include/mlir/Dialect/Transform/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/Transform/CMakeLists.txt
@@ -1,6 +1,7 @@
 add_subdirectory(DebugExtension)
 add_subdirectory(Interfaces)
 add_subdirectory(IR)
+add_subdirectory(IRDLExtension)
 add_subdirectory(LoopExtension)
 add_subdirectory(PDLExtension)
 add_subdirectory(Transforms)
diff --git a/mlir/include/mlir/Dialect/Transform/IRDLExtension/CMakeLists.txt b/mlir/include/mlir/Dialect/Transform/IRDLExtension/CMakeLists.txt
new file mode 100644
index 00000000000000..dfcd906b43af04
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Transform/IRDLExtension/CMakeLists.txt
@@ -0,0 +1,6 @@
+set(LLVM_TARGET_DEFINITIONS IRDLExtensionOps.td)
+mlir_tablegen(IRDLExtensionOps.h.inc -gen-op-decls)
+mlir_tablegen(IRDLExtensionOps.cpp.inc -gen-op-defs)
+add_public_tablegen_target(MLIRTransformDialectIRDLExtensionOpsIncGen)
+
+add_mlir_doc(IRDLExtensionOps IRDLExtensionOps Dialects/ -gen-op-doc)
diff --git a/mlir/include/mlir/Dialect/Transform/IRDLExtension/IRDLExtension.h b/mlir/include/mlir/Dialect/Transform/IRDLExtension/IRDLExtension.h
new file mode 100644
index 00000000000000..19684e1ed44468
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Transform/IRDLExtension/IRDLExtension.h
@@ -0,0 +1,21 @@
+//===- IRDLExtension.h - IRDL extension for Transform dialect ---*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_TRANSFORM_IRDLEXTENSION_IRDLEXTENSION_H
+#define MLIR_DIALECT_TRANSFORM_IRDLEXTENSION_IRDLEXTENSION_H
+
+namespace mlir {
+class DialectRegistry;
+
+namespace transform {
+/// Registers the IRDL extension of the Transform dialect in the given registry.
+void registerIRDLExtension(DialectRegistry &dialectRegistry);
+} // namespace transform
+} // namespace mlir
+
+#endif // MLIR_DIALECT_TRANSFORM_IRDLEXTENSION_IRDLEXTENSION_H
diff --git a/mlir/include/mlir/Dialect/Transform/IRDLExtension/IRDLExtensionOps.h b/mlir/include/mlir/Dialect/Transform/IRDLExtension/IRDLExtensionOps.h
new file mode 100644
index 00000000000000..7e1d5cad1fbd88
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Transform/IRDLExtension/IRDLExtensionOps.h
@@ -0,0 +1,20 @@
+//===- IRDLExtensionOps.h - IRDL Transform dialect extension ----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_TRANSFORM_IRDLEXTENSION_IRDLEXTENSIONOPS_H
+#define MLIR_DIALECT_TRANSFORM_IRDLEXTENSION_IRDLEXTENSIONOPS_H
+
+#include "mlir/Bytecode/BytecodeOpInterface.h"
+#include "mlir/Dialect/Transform/IR/TransformDialect.h"
+#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
+#include "mlir/IR/OpDefinition.h"
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/Transform/IRDLExtension/IRDLExtensionOps.h.inc"
+
+#endif // MLIR_DIALECT_TRANSFORM_IRDLEXTENSION_IRDLEXTENSIONOPS_H
diff --git a/mlir/include/mlir/Dialect/Transform/IRDLExtension/IRDLExtensionOps.td b/mlir/include/mlir/Dialect/Transform/IRDLExtension/IRDLExtensionOps.td
new file mode 100644
index 00000000000000..6ca624aeda12c7
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Transform/IRDLExtension/IRDLExtensionOps.td
@@ -0,0 +1,36 @@
+//===- IRDLExtensionOps.td - Transform dialect extension ---*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_TRANSFORM_IRDLEXTENSION_IRDLEXTENSIONOPS
+#define MLIR_DIALECT_TRANSFORM_IRDLEXTENSION_IRDLEXTENSIONOPS
+
+include "mlir/Dialect/Transform/IR/TransformDialect.td"
+include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td"
+include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/IR/SymbolInterfaces.td"
+
+def IRDLCollectMatchingOp : TransformDialectOp<"irdl.collect_matching",
+    [DeclareOpInterfaceMethods<TransformOpInterface>,
+     DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+     SymbolTable,
+     NoTerminator]> {
+  let summary = 
+    "Finds ops that match the IRDL definition without registering them.";
+
+  let arguments = (ins TransformHandleTypeInterface:$root);
+  let regions = (region SizedRegion<1>:$body);
+  let results = (outs TransformHandleTypeInterface:$matched);
+
+  let assemblyFormat =
+    "`in` $root `:` functional-type(operands, results) attr-dict-with-keyword "
+    "regions";
+
+  let hasVerifier = 1;
+}
+
+#endif // MLIR_DIALECT_TRANSFORM_IRDLEXTENSION_IRDLEXTENSIONOPS
diff --git a/mlir/include/mlir/Dialect/Transform/PDLExtension/PDLExtension.h b/mlir/include/mlir/Dialect/Transform/PDLExtension/PDLExtension.h
index 08915213cd22c5..bf5a105bc9f29b 100644
--- a/mlir/include/mlir/Dialect/Transform/PDLExtension/PDLExtension.h
+++ b/mlir/include/mlir/Dialect/Transform/PDLExtension/PDLExtension.h
@@ -6,6 +6,9 @@
 //
 //===----------------------------------------------------------------------===//
 
+#ifndef MLIR_DIALECT_TRANSFORM_PDLEXTENSION_PDLEXTENSION_H
+#define MLIR_DIALECT_TRANSFORM_PDLEXTENSION_PDLEXTENSION_H
+
 namespace mlir {
 class DialectRegistry;
 
@@ -14,3 +17,5 @@ namespace transform {
 void registerPDLExtension(DialectRegistry &dialectRegistry);
 } // namespace transform
 } // namespace mlir
+
+#endif // MLIR_DIALECT_TRANSFORM_PDLEXTENSION_PDLEXTENSION_H
diff --git a/mlir/include/mlir/InitAllExtensions.h b/mlir/include/mlir/InitAllExtensions.h
index 7708ca5571de3b..20a4ab6f18a286 100644
--- a/mlir/include/mlir/InitAllExtensions.h
+++ b/mlir/include/mlir/InitAllExtensions.h
@@ -35,6 +35,7 @@
 #include "mlir/Dialect/SparseTensor/TransformOps/SparseTensorTransformOps.h"
 #include "mlir/Dialect/Tensor/TransformOps/TensorTransformOps.h"
 #include "mlir/Dialect/Transform/DebugExtension/DebugExtension.h"
+#include "mlir/Dialect/Transform/IRDLExtension/IRDLExtension.h"
 #include "mlir/Dialect/Transform/LoopExtension/LoopExtension.h"
 #include "mlir/Dialect/Transform/PDLExtension/PDLExtension.h"
 #include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.h"
@@ -77,6 +78,7 @@ inline void registerAllExtensions(DialectRegistry &registry) {
   sparse_tensor::registerTransformDialectExtension(registry);
   tensor::registerTransformDialectExtension(registry);
   transform::registerDebugExtension(registry);
+  transform::registerIRDLExtension(registry);
   transform::registerLoopExtension(registry);
   transform::registerPDLExtension(registry);
   vector::registerTransformDialectExtension(registry);
diff --git a/mlir/lib/Dialect/IRDL/IRDLLoading.cpp b/mlir/lib/Dialect/IRDL/IRDLLoading.cpp
index cfc8d092c8178a..5df2b45d8037b3 100644
--- a/mlir/lib/Dialect/IRDL/IRDLLoading.cpp
+++ b/mlir/lib/Dialect/IRDL/IRDLLoading.cpp
@@ -270,26 +270,30 @@ static LogicalResult irdlRegionVerifier(
   return success();
 }
 
-/// Define and load an operation represented by a `irdl.operation`
-/// operation.
-static WalkResult loadOperation(
-    OperationOp op, ExtensibleDialect *dialect,
-    DenseMap<TypeOp, std::unique_ptr<DynamicTypeDefinition>> &types,
-    DenseMap<AttributeOp, std::unique_ptr<DynamicAttrDefinition>> &attrs) {
+llvm::unique_function<LogicalResult(Operation *) const>
+mlir::irdl::createVerifier(
+    OperationOp op,
+    const DenseMap<irdl::TypeOp, std::unique_ptr<DynamicTypeDefinition>> &types,
+    const DenseMap<irdl::AttributeOp, std::unique_ptr<DynamicAttrDefinition>>
+        &attrs) {
   // Resolve SSA values to verifier constraint slots
   SmallVector<Value> constrToValue;
   SmallVector<Value> regionToValue;
   for (Operation &op : op->getRegion(0).getOps()) {
     if (isa<VerifyConstraintInterface>(op)) {
-      if (op.getNumResults() != 1)
-        return op.emitError()
-               << "IRDL constraint operations must have exactly one result";
+      if (op.getNumResults() != 1) {
+        op.emitError()
+            << "IRDL constraint operations must have exactly one result";
+        return nullptr;
+      }
       constrToValue.push_back(op.getResult(0));
     }
     if (isa<VerifyRegionInterface>(op)) {
-      if (op.getNumResults() != 1)
-        return op.emitError()
-               << "IRDL constraint operations must have exactly one result";
+      if (op.getNumResults() != 1) {
+        op.emitError()
+            << "IRDL constraint operations must have exactly one result";
+        return nullptr;
+      }
       regionToValue.push_back(op.getResult(0));
     }
   }
@@ -302,7 +306,7 @@ static WalkResult loadOperation(
     std::unique_ptr<Constraint> verifier =
         op.getVerifier(constrToValue, types, attrs);
     if (!verifier)
-      return WalkResult::interrupt();
+      return nullptr;
     constraints.push_back(std::move(verifier));
   }
 
@@ -358,7 +362,7 @@ static WalkResult loadOperation(
   }
 
   // Gather which constraint slots correspond to attributes constraints
-  DenseMap<StringAttr, size_t> attributesContraints;
+  DenseMap<StringAttr, size_t> attributeConstraints;
   auto attributesOp = op.getOp<AttributesOp>();
   if (attributesOp.has_value()) {
     const Operation::operand_range values = attributesOp->getAttributeValues();
@@ -367,40 +371,53 @@ static WalkResult loadOperation(
     for (const auto &[name, value] : llvm::zip(names, values)) {
       for (auto [i, constr] : enumerate(constrToValue)) {
         if (constr == value) {
-          attributesContraints[cast<StringAttr>(name)] = i;
+          attributeConstraints[cast<StringAttr>(name)] = i;
           break;
         }
       }
     }
   }
 
-  // IRDL does not support defining custom parsers or printers.
-  auto parser = [](OpAsmParser &parser, OperationState &result) {
-    return failure();
-  };
-  auto printer = [](Operation *op, OpAsmPrinter &printer, StringRef) {
-    printer.printGenericOp(op);
-  };
-
-  auto verifier =
+  return
       [constraints{std::move(constraints)},
        regionConstraints{std::move(regionConstraints)},
        operandConstraints{std::move(operandConstraints)},
        operandVariadicity{std::move(operandVariadicity)},
        resultConstraints{std::move(resultConstraints)},
        resultVariadicity{std::move(resultVariadicity)},
-       attributesContraints{std::move(attributesContraints)}](Operation *op) {
+       attributeConstraints{std::move(attributeConstraints)}](Operation *op) {
         ConstraintVerifier verifier(constraints);
         const LogicalResult opVerifierResult = irdlOpVerifier(
             op, verifier, operandConstraints, operandVariadicity,
-            resultConstraints, resultVariadicity, attributesContraints);
+            resultConstraints, resultVariadicity, attributeConstraints);
         const LogicalResult opRegionVerifierResult =
             irdlRegionVerifier(op, verifier, regionConstraints);
         return LogicalResult::success(opVerifierResult.succeeded() &&
                                       opRegionVerifierResult.succeeded());
       };
+}
+
+/// Define and load an operation represented by a `irdl.operation`
+/// operation.
+static WalkResult loadOperation(
+    OperationOp op, ExtensibleDialect *dialect,
+    const DenseMap<TypeOp, std::unique_ptr<DynamicTypeDefinition>> &types,
+    const DenseMap<AttributeOp, std::unique_ptr<DynamicAttrDefinition>>
+        &attrs) {
+
+  // IRDL does not support defining custom parsers or printers.
+  auto parser = [](OpAsmParser &parser, OperationState &result) {
+    return failure();
+  };
+  auto printer = [](Operation *op, OpAsmPrinter &printer, StringRef) {
+    printer.printGenericOp(op);
+  };
+
+  auto verifier = createVerifier(op, types, attrs);
+  if (!verifier)
+    return WalkResult::interrupt();
 
-  // IRDL supports only checking number of blocks and argument contraints
+  // IRDL supports only checking number of blocks and argument constraints
   // It is done in the main verifier to reuse `ConstraintVerifier` context
   auto regionVerifier = [](Operation *op) { return LogicalResult::success(); };
 
diff --git a/mlir/lib/Dialect/Transform/CMakeLists.txt b/mlir/lib/Dialect/Transform/CMakeLists.txt
index 64115dcc29d639..0c0d5ebe0c212e 100644
--- a/mlir/lib/Dialect/Transform/CMakeLists.txt
+++ b/mlir/lib/Dialect/Transform/CMakeLists.txt
@@ -1,6 +1,7 @@
 add_subdirectory(DebugExtension)
 add_subdirectory(Interfaces)
 add_subdirectory(IR)
+add_subdirectory(IRDLExtension)
 add_subdirectory(LoopExtension)
 add_subdirectory(PDLExtension)
 add_subdirectory(Transforms)
diff --git a/mlir/lib/Dialect/Transform/IRDLExtension/CMakeLists.txt b/mlir/lib/Dialect/Transform/IRDLExtension/CMakeLists.txt
new file mode 100644
index 00000000000000..9216a3d722021f
--- /dev/null
+++ b/mlir/lib/Dialect/Transform/IRDLExtension/CMakeLists.txt
@@ -0,0 +1,12 @@
+add_mlir_dialect_library(MLIRTransformDialectIRDLExtension
+  IRDLExtension.cpp
+  IRDLExtensionOps.cpp
+
+  DEPENDS
+  MLIRTransformDialectIRDLExtensionOpsIncGen
+
+  LINK_LIBS PUBLIC
+  MLIRIR
+  MLIRTransformDialect
+  MLIRIRDL
+)
diff --git a/mlir/lib/Dialect/Transform/IRDLExtension/IRDLExtension.cpp b/mlir/lib/Dialect/Transform/IRDLExtension/IRDLExtension.cpp
new file mode 100644
index 00000000000000..94004365b8a1a5
--- /dev/null
+++ b/mlir/lib/Dialect/Transform/IRDLExtension/IRDLExtension.cpp
@@ -0,0 +1,34 @@
+//===- IRDLExtension.cpp - IRDL extension for the Transform dialect -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Transform/IRDLExtension/IRDLExtension.h"
+#include "mlir/Dialect/IRDL/IR/IRDL.h"
+#include "mlir/Dialect/Transform/IR/TransformDialect.h"
+#include "mlir/Dialect/Transform/IRDLExtension/IRDLExtensionOps.h"
+#include "mlir/IR/DialectRegistry.h"
+
+using namespace mlir;
+
+namespace {
+class IRDLExtension
+    : public transform::TransformDialectExtension<IRDLExtension> {
+public:
+  void init() {
+    registerTransformOps<
+#define GET_OP_LIST
+#include "mlir/Dialect/Transform/IRDLExtension/IRDLExtensionOps.cpp.inc"
+        >();
+
+    declareDependentDialect<irdl::IRDLDialect>();
+  }
+};
+} // namespace
+
+void mlir::transform::registerIRDLExtension(DialectRegistry &dialectRegistry) {
+  dialectRegistry.addExtensions<IRDLExtension>();
+}
diff --git a/mlir/lib/Dialect/Transform/IRDLExtension/IRDLExtensionOps.cpp b/mlir/lib/Dialect/Transform/IRDLExtension/IRDLExtensionOps.cpp
new file mode 100644
index 00000000000000..9cc579e65edf91
--- /dev/null
+++ b/mlir/lib/Dialect/Transform/IRDLExtension/IRDLExtensionOps.cpp
@@ -0,0 +1,84 @@
+//===- IRDLExtensionOps.cpp - IRDL extension for the Transform dialect ----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Transform/IRDLExtension/IRDLExtensionOps.h"
+#include "mlir/Dialect/IRDL/IR/IRDL.h"
+#include "mlir/Dialect/IRDL/IRDLVerifiers.h"
+#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/ExtensibleDialect.h"
+#include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "llvm/ADT/STLExtras.h"
+
+using namespace mlir;
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/Transform/IRDLExtension/IRDLExtensionOps.cpp.inc"
+
+namespace mlir::transform {
+
+DiagnosedSilenceableFailure
+IRDLCollectMatchingOp::apply(TransformRewriter &rewriter,
+                             TransformResults &results, TransformState &state) {
+  auto dialect = cast<irdl::DialectOp>(getBody().front().front());
+  Block &body = dialect.getBody().front();
+  irdl::OperationOp operation = *body.getOps<irdl::OperationOp>().begin();
+  auto verifier = irdl::createVerifier(
+      operation,
+      DenseMap<irdl::TypeOp, std::unique_ptr<DynamicTypeDefinition>>(),
+      DenseMap<irdl::AttributeOp, std::unique_ptr<DynamicAttrDefinition>>());
+
+  auto handlerID = getContext()->getDiagEngine().registerHandler(
+      [](Diagnostic &) { return success(); });
+  SmallVector<Operation *> matched;
+  for (Operation *payload : state.getPayloadOps(getRoot())) {
+    payload->walk([&](Operation *target) {
+      if (succeeded(verifier(target))) {
+        matched.push_back(target);
+      }
+    });
+  }
+  getContext()->getDiagEngine().eraseHandler(handlerID);
+  results.set(cast<OpResult>(getMatched()), matched);
+  return DiagnosedSilenceableFailure::success();
+}
+
+void IRDLCollectMatchingOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  onlyReadsHandle(getRoot(), effects);
+  producesHandle(getMatched(), effects);
+  onlyReadsPayload(effects);
+}
+
+LogicalResult IRDLCollectMatchingOp::verify() {
+  Block &bodyBlock = getBody().front();
+  if (!llvm::hasSingleElement(bodyBlock))
+    return emitOpError() << "expects a single operation in the body";
+
+  auto dialect = dyn_cast<irdl::DialectOp>(bodyBlock.front());
+  if (!dialect) {
+    return emitOpError() << "expects the body operation to be "
+                         << irdl::DialectOp::getOperationName();
+  }
+
+  // TODO: relax this by taking a symbol name of the operation to match, note
+  // that symbol name is also the name of the operation and we may want to
+  // divert from that to have constraints on-the-fly using IRDL.
+  auto irdlOperations = dialect.getOps<irdl::OperationOp>();
+  if (!llvm::hasSingleElement(irdlOperations))
+    return emitOpError() << "expects IRDL to contain exactly one operation";
+
+  if (!dialect.getOps<irdl::TypeOp>().empty() ||
+      !dialect.getOps<irdl::AttributeOp>().empty()) {
+    return emitOpError() << "IRDL types and attributes are not yet supported";
+  }
+
+  return success();
+}
+
+} // namespace mlir::transform
diff --git a/mlir/test/Dialect/Transform/irdl.mlir b/mlir/test/Dialect/Transform/irdl.mlir
new file mode 100644
index 00000000000000..d3faea0dffcc26
--- /dev/null
+++ b/mlir/test/Dialect/Transform/irdl.mlir
@@ -0,0 +1,25 @@
+// RUN: mlir-opt --transform-interpreter --split-input-file --verify-diagnostics
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op) {
+    %0 = transform.irdl.collect_matching in %arg0 : (!transform.any_op) -> (!transform.any_op){
+    ^bb0(%arg1: !transform.any_op):
+      irdl.dialect @test {
+        irdl.operation @whatever {
+          %0 = irdl.is i32
+          %1 = irdl.i...
[truncated]

@ftynse ftynse merged commit 105c992 into llvm:main May 2, 2024
@ftynse ftynse deleted the td-irdl-2 branch May 2, 2024 13:03
@joker-eph
Copy link
Collaborator

@ftynse : this seems not well plugged in the doc generation, on the website it ends up with an empty box in the dialect list in the menu:

Screenshot 2024-05-15 at 3 17 48 PM

@ftynse
Copy link
Member Author

ftynse commented May 16, 2024

Thanks for noticing, 51403ad should fix.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants