[mlir] Add config for PDL #69927

Merged
merged 1 commit into from
Jan 3, 2024

Conversation

jpienaar
Member

@jpienaar jpienaar commented Oct 23, 2023

Make it so that PDL in pattern rewrites can be optionally disabled.

PDL is still enabled by default and is not optional in Bazel. So this should be a NOP for most folks, while allowing others to disable it.

This is piped through the mlir-tblgen invocation; that could be changed or avoided by splitting up the passes file instead.

This only works with tests disabled. With tests enabled this still compiles but tests fail as there is no lit config to disable tests that depend on PDL rewrites yet.
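The patch exposes the option to C++ through `#cmakedefine01 MLIR_ENABLE_PDL` in `mlir-config.h.cmake`. With `#cmakedefine01`, CMake always defines the macro, as either 0 or 1, so guards have to test its value with `#if` rather than its existence with `#ifdef`. A minimal sketch of the difference, assuming a stand-in definition in place of the generated header (the helper names `describeBuild` and `macroIsDefined` are hypothetical):

```cpp
#include <cassert>
#include <string>

// Stand-in for the generated mlir/Config/mlir-config.h, which would contain
// either `#define MLIR_ENABLE_PDL 1` or `#define MLIR_ENABLE_PDL 0`.
#ifndef MLIR_ENABLE_PDL
#define MLIR_ENABLE_PDL 0
#endif

// Tests the macro's value: the correct guard for a 0/1 macro.
inline std::string describeBuild() {
#if MLIR_ENABLE_PDL
  return "PDL enabled";
#else
  return "PDL disabled";
#endif
}

// Tests only whether the macro exists. This is true even for a PDL-free
// build, so `#ifdef` would wrongly keep PDL-only code in.
inline bool macroIsDefined() {
#ifdef MLIR_ENABLE_PDL
  return true;
#else
  return false;
#endif
}
```

In this sketch the macro is defined to 0, so `macroIsDefined()` still returns true while `describeBuild()` reports that PDL is disabled.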

@llvmbot added the mlir:core, mlir:python, mlir, and mlir:bufferization labels on Oct 23, 2023
@joker-eph
Collaborator

Can you expand on the motivation? Why is it a problem that PDL is always included? Your description isn’t very explicit on the impact.

@jpienaar
Member Author

> Can you expand on the motivation? Why is it a problem that PDL is always included? Your description isn’t very explicit on the impact.

Done. It's not PDL-specific; I think it is a problem if any dialect is always included even if not used :) The others just have a simple method to elide them.

@joker-eph
Collaborator

joker-eph commented Oct 28, 2023

Need to iterate a bit more on framing the problem:

> MLIR is a general infrastructure and as far as possible no dialect is intended to be special or privileged

Yes, and you can use MLIR without PDL; it's not a mandatory component. I believe that the minimal examples mlir-cat and mlir-minimal-opt don't have PDL linked in (so you can already have your *-opt tool without it!).

(There are other things I'd like to make configurable, like the printer or the canonicalization patterns! These are always linked in whether you use them or not.)

> Users should be able to use MLIR and the parts they need to customize for their solution (e.g., one doesn't have to include any dialect except the ones one uses).

I don't see PDL as "just a dialect": I see it more as an infrastructure component.
That is, no one will have PDL used within their IR!
It isn't comparable to any other dialect and is completely unique from this point of view (do we have other dialects that aren't targeting user-written compilers? The transform dialect might be in-between).

> Many parts are elided when not referenced (LTO DCE'd etc), but this is not possible with PDL given how it's integrated with the common rewrite drivers. This results in it always being included even when not used.

Now that seems more accurate to me: the dependency on PDL is only from the rewrite drivers, which is a "bring your own" thing by the way.
So the underlying question for motivating this here would be better framed IMO as "should we have an option to build the GreedyPatternRewriter without PDL?"
I see it less as a strong need here (because again, bring your own driver if you don't like it).

Without a stronger case, this seems like something I would be supportive of if we can make it minimally intrusive: that is, localize the changes as much as possible and not spread them through the codebase.
There are far too many #ifdef for my taste right now; this may need some more refactoring first to avoid this and make PDL a more proper "separate component" inside libMLIRRewrite so that it can be enabled/disabled more naturally.
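One way to read the "separate component" suggestion is to confine the preprocessor check to a single facade so that drivers never need their own guards. The sketch below is purely illustrative and is not what the patch does: `PDLSupport`, `compilePatterns`, and `countApplicablePatterns` are hypothetical names, the macro definition stands in for the generated config header, and pattern "compilation" is reduced to counting strings.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Stand-in for the value generated into mlir/Config/mlir-config.h.
#ifndef MLIR_ENABLE_PDL
#define MLIR_ENABLE_PDL 0
#endif

namespace sketch {

// Hypothetical facade: the only place that knows whether PDL is built in.
class PDLSupport {
public:
  // True when PDL patterns can be compiled and applied.
  static constexpr bool isEnabled() { return MLIR_ENABLE_PDL != 0; }

  // Returns how many PDL patterns were accepted; always zero when disabled.
  static int compilePatterns(const std::vector<std::string> &pdlPatterns) {
#if MLIR_ENABLE_PDL
    return static_cast<int>(pdlPatterns.size());
#else
    (void)pdlPatterns; // Nothing to compile in a PDL-free build.
    return 0;
#endif
  }
};

// A driver written against the facade contains no preprocessor conditionals.
inline int countApplicablePatterns(int nativePatterns,
                                   const std::vector<std::string> &pdlPatterns) {
  return nativePatterns + PDLSupport::compilePatterns(pdlPatterns);
}

} // namespace sketch
```

The trade-off is that a facade hides missing guards at compile time, whereas the description notes that forgetting an `#if` in the current layout fails the build, which is easy to trace.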

@llvmbot
Member

llvmbot commented Dec 10, 2023

@llvm/pr-subscribers-mlir-pdl

@llvm/pr-subscribers-mlir-vector

Author: Jacques Pienaar (jpienaar)

Changes

Make it so that PDL can be optionally disabled. PDL is different from other dialects as it's included in the core rewrite framework. This results in it being included even where it isn't used and not removed during compilation. Add an option to disable it for workloads where it isn't needed or can't be used. This ends up being rather invasive due to how PDL is included. Ideally we'd have fewer #if's, but I didn't want to change the structure purely to reduce those (it felt better to keep the organization in case anything needs to be changed; forgetting to guard results in compile-time failures, which should be easy to trace).

MLIR is a general infrastructure and as far as possible no dialect is intended to be special or privileged (we have Builtin which is both and should be addressed ...). Users should be able to use MLIR and the parts they need to customize for their solution (e.g., one doesn't have to include any dialect except the ones one uses). Many parts are elided when not referenced (LTO DCE'd etc), but this is not possible with PDL given how it's integrated with the common rewrite drivers. This results in it always being included even when not used. The most obvious impact is a reduction in size when using rewrites but not PDL: the size increase from adding the rewrite driver in the minimal case drops from 45% to 12%.
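The 45% and 12% figures can be roughly reproduced from the Base column of the README table in the patch, comparing `mlir-minimal-opt` with `mlir-minimal-opt-canonicalize` before and after the change. A small sketch of that arithmetic, assuming the rounded table values (the helper and constant names are made up for illustration):

```cpp
#include <cassert>

// Percentage by which `withDriver` exceeds `base`.
inline double percentIncrease(double base, double withDriver) {
  assert(base > 0.0 && "base size must be positive");
  return (withDriver - base) / base * 100.0;
}

// Base-column sizes in MB, copied from the README table in the patch.
// "Before" rows are the ones the patch removes; "After" rows are the ones it
// adds, measured with -DMLIR_ENABLE_PDL=OFF.
constexpr double kMinimalOptBefore = 1.54;
constexpr double kCanonicalizeBefore = 2.24;
constexpr double kMinimalOptAfter = 1.62;
constexpr double kCanonicalizeAfter = 1.83;
```

With these inputs the increase is about 45% before the patch and about 13% after it; the quoted 12% presumably comes from unrounded sizes.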

PDL is still enabled by default and is not optional in Bazel. So this should be a NOP for most folks, while allowing others to disable it.

This is piped through mlir-tblgen invocation and that could be changed/avoided by splitting up the passes file instead.

This only works with tests disabled. With tests enabled this still compiles but tests fail as there is no lit config to disable tests that depend on PDL yet.


Patch is 102.13 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/69927.diff

28 Files Affected:

  • (modified) mlir/CMakeLists.txt (+4-2)
  • (modified) mlir/examples/minimal-opt/README.md (+5-4)
  • (modified) mlir/include/mlir/Config/mlir-config.h.cmake (+3)
  • (modified) mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h (+1)
  • (modified) mlir/include/mlir/Dialect/Vector/IR/VectorOps.h (+1)
  • (added) mlir/include/mlir/IR/PDLPatternMatch.h.inc (+992)
  • (modified) mlir/include/mlir/IR/PatternMatch.h (+9-926)
  • (modified) mlir/include/mlir/InitAllDialects.h (+1)
  • (modified) mlir/include/mlir/InitAllExtensions.h (+2-1)
  • (modified) mlir/include/mlir/Rewrite/PatternApplicator.h (+1)
  • (modified) mlir/include/mlir/Transforms/DialectConversion.h (+15)
  • (modified) mlir/lib/Dialect/Bufferization/TransformOps/CMakeLists.txt (-1)
  • (modified) mlir/lib/IR/PatternMatch.cpp (+3)
  • (modified) mlir/lib/Rewrite/ByteCode.cpp (+2)
  • (modified) mlir/lib/Rewrite/ByteCode.h (+35)
  • (modified) mlir/lib/Rewrite/CMakeLists.txt (+11)
  • (modified) mlir/lib/Rewrite/FrozenRewritePatternSet.cpp (+9-2)
  • (modified) mlir/lib/Rewrite/PatternApplicator.cpp (+3-1)
  • (modified) mlir/lib/Tools/CMakeLists.txt (+1-1)
  • (modified) mlir/lib/Transforms/Utils/DialectConversion.cpp (+2)
  • (modified) mlir/python/CMakeLists.txt (+19-19)
  • (modified) mlir/test/CMakeLists.txt (+8-4)
  • (modified) mlir/test/lib/Rewrite/CMakeLists.txt (-1)
  • (modified) mlir/test/lib/Transforms/CMakeLists.txt (+12-2)
  • (modified) mlir/tools/mlir-lsp-server/CMakeLists.txt (+4-2)
  • (modified) mlir/tools/mlir-opt/CMakeLists.txt (+5-3)
  • (modified) mlir/tools/mlir-opt/mlir-opt.cpp (+10-6)
  • (modified) utils/bazel/llvm-project-overlay/mlir/BUILD.bazel (+1)
diff --git a/mlir/CMakeLists.txt b/mlir/CMakeLists.txt
index 16ff950089734b..3ff21ed0d5aa81 100644
--- a/mlir/CMakeLists.txt
+++ b/mlir/CMakeLists.txt
@@ -133,6 +133,8 @@ set(MLIR_ENABLE_NVPTXCOMPILER 0 CACHE BOOL
     "Statically link the nvptxlibrary instead of calling ptxas as a subprocess \
     for compiling PTX to cubin")
 
+set(MLIR_ENABLE_PDL 1 CACHE BOOL "Enable PDL")
+
 option(MLIR_INCLUDE_TESTS
        "Generate build targets for the MLIR unit tests."
        ${LLVM_INCLUDE_TESTS})
@@ -180,10 +182,10 @@ include_directories( ${MLIR_INCLUDE_DIR})
 # from another directory like tools
 add_subdirectory(tools/mlir-tblgen)
 add_subdirectory(tools/mlir-linalg-ods-gen)
-add_subdirectory(tools/mlir-pdll)
-
 set(MLIR_TABLEGEN_EXE "${MLIR_TABLEGEN_EXE}" CACHE INTERNAL "")
 set(MLIR_TABLEGEN_TARGET "${MLIR_TABLEGEN_TARGET}" CACHE INTERNAL "")
+
+add_subdirectory(tools/mlir-pdll)
 set(MLIR_PDLL_TABLEGEN_EXE "${MLIR_PDLL_TABLEGEN_EXE}" CACHE INTERNAL "")
 set(MLIR_PDLL_TABLEGEN_TARGET "${MLIR_PDLL_TABLEGEN_TARGET}" CACHE INTERNAL "")
 
diff --git a/mlir/examples/minimal-opt/README.md b/mlir/examples/minimal-opt/README.md
index b8a455f7a79662..09d0f20c34e426 100644
--- a/mlir/examples/minimal-opt/README.md
+++ b/mlir/examples/minimal-opt/README.md
@@ -14,10 +14,10 @@ Below are some example measurements taken at the time of the LLVM 17 release,
 using clang-14 on a X86 Ubuntu and [bloaty](https://github.com/google/bloaty).
 
 |                                  | Base   | Os     | Oz     | Os LTO | Oz LTO |
-| :-----------------------------: | ------ | ------ | ------ | ------ | ------ |
-| `mlir-cat`                      | 1018kB | 836KB  | 879KB  | 697KB  | 649KB  |
-| `mlir-minimal-opt`              | 1.54MB | 1.25MB | 1.29MB | 1.10MB | 1.00MB |
-| `mlir-minimal-opt-canonicalize` | 2.24MB | 1.81MB | 1.86MB | 1.62MB | 1.48MB |
+| :------------------------------: | ------ | ------ | ------ | ------ | ------ |
+| `mlir-cat`                       | 1024KB |  840KB |  885KB |  706KB |  657KB |
+| `mlir-minimal-opt`               | 1.62MB | 1.32MB | 1.36MB | 1.17MB | 1.07MB |
+| `mlir-minimal-opt-canonicalize`  | 1.83MB | 1.40MB | 1.45MB | 1.25MB | 1.14MB |
 
 Base configuration:
 
@@ -32,6 +32,7 @@ cmake ../llvm/ -G Ninja \
    -DCMAKE_CXX_COMPILER=clang++ \
    -DLLVM_ENABLE_LLD=ON \
    -DLLVM_ENABLE_BACKTRACES=OFF \
+   -DMLIR_ENABLE_PDL=OFF \
    -DCMAKE_EXE_LINKER_FLAGS_RELWITHDEBINFO=-Wl,-icf=all
 ```
 
diff --git a/mlir/include/mlir/Config/mlir-config.h.cmake b/mlir/include/mlir/Config/mlir-config.h.cmake
index efa77b2e5ce5db..71e394eeec6133 100644
--- a/mlir/include/mlir/Config/mlir-config.h.cmake
+++ b/mlir/include/mlir/Config/mlir-config.h.cmake
@@ -26,4 +26,7 @@
    numeric seed that is passed to the random number generator. */
 #cmakedefine MLIR_GREEDY_REWRITE_RANDOMIZER_SEED ${MLIR_GREEDY_REWRITE_RANDOMIZER_SEED}
 
+/* If set, enables PDL usage. */
+#cmakedefine01 MLIR_ENABLE_PDL
+
 #endif
diff --git a/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h b/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
index 74f9c977b70286..e228229302cff4 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
@@ -15,6 +15,7 @@
 #define MLIR_CONVERSION_LLVMCOMMON_TYPECONVERTER_H
 
 #include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
+#include "mlir/IR/BuiltinTypes.h"
 #include "mlir/Transforms/DialectConversion.h"
 
 namespace mlir {
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
index 59d585a77b1e29..9dfbd0ff31ecea 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
@@ -29,6 +29,7 @@
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "mlir/Interfaces/VectorInterfaces.h"
 #include "mlir/Interfaces/ViewLikeInterface.h"
+#include "llvm/ADT/SetVector.h"
 #include "llvm/ADT/StringExtras.h"
 
 // Pull in all enum type definitions and utility function declarations.
diff --git a/mlir/include/mlir/IR/PDLPatternMatch.h.inc b/mlir/include/mlir/IR/PDLPatternMatch.h.inc
new file mode 100644
index 00000000000000..e4d63248bf1908
--- /dev/null
+++ b/mlir/include/mlir/IR/PDLPatternMatch.h.inc
@@ -0,0 +1,992 @@
+//===- PDLPatternMatch.h - PDLPatternMatcher classes -------==---*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_IR_PDLPATTERNMATCH_H
+#define MLIR_IR_PDLPATTERNMATCH_H
+
+#if MLIR_ENABLE_PDL
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinOps.h"
+
+namespace mlir {
+//===----------------------------------------------------------------------===//
+// PDL Patterns
+//===----------------------------------------------------------------------===//
+
+//===----------------------------------------------------------------------===//
+// PDLValue
+
+/// Storage type of byte-code interpreter values. These are passed to constraint
+/// functions as arguments.
+class PDLValue {
+public:
+  /// The underlying kind of a PDL value.
+  enum class Kind { Attribute, Operation, Type, TypeRange, Value, ValueRange };
+
+  /// Construct a new PDL value.
+  PDLValue(const PDLValue &other) = default;
+  PDLValue(std::nullptr_t = nullptr) {}
+  PDLValue(Attribute value)
+      : value(value.getAsOpaquePointer()), kind(Kind::Attribute) {}
+  PDLValue(Operation *value) : value(value), kind(Kind::Operation) {}
+  PDLValue(Type value) : value(value.getAsOpaquePointer()), kind(Kind::Type) {}
+  PDLValue(TypeRange *value) : value(value), kind(Kind::TypeRange) {}
+  PDLValue(Value value)
+      : value(value.getAsOpaquePointer()), kind(Kind::Value) {}
+  PDLValue(ValueRange *value) : value(value), kind(Kind::ValueRange) {}
+
+  /// Returns true if the type of the held value is `T`.
+  template <typename T>
+  bool isa() const {
+    assert(value && "isa<> used on a null value");
+    return kind == getKindOf<T>();
+  }
+
+  /// Attempt to dynamically cast this value to type `T`, returns null if this
+  /// value is not an instance of `T`.
+  template <typename T,
+            typename ResultT = std::conditional_t<
+                std::is_convertible<T, bool>::value, T, std::optional<T>>>
+  ResultT dyn_cast() const {
+    return isa<T>() ? castImpl<T>() : ResultT();
+  }
+
+  /// Cast this value to type `T`, asserts if this value is not an instance of
+  /// `T`.
+  template <typename T>
+  T cast() const {
+    assert(isa<T>() && "expected value to be of type `T`");
+    return castImpl<T>();
+  }
+
+  /// Get an opaque pointer to the value.
+  const void *getAsOpaquePointer() const { return value; }
+
+  /// Return if this value is null or not.
+  explicit operator bool() const { return value; }
+
+  /// Return the kind of this value.
+  Kind getKind() const { return kind; }
+
+  /// Print this value to the provided output stream.
+  void print(raw_ostream &os) const;
+
+  /// Print the specified value kind to an output stream.
+  static void print(raw_ostream &os, Kind kind);
+
+private:
+  /// Find the index of a given type in a range of other types.
+  template <typename...>
+  struct index_of_t;
+  template <typename T, typename... R>
+  struct index_of_t<T, T, R...> : std::integral_constant<size_t, 0> {};
+  template <typename T, typename F, typename... R>
+  struct index_of_t<T, F, R...>
+      : std::integral_constant<size_t, 1 + index_of_t<T, R...>::value> {};
+
+  /// Return the kind used for the given T.
+  template <typename T>
+  static Kind getKindOf() {
+    return static_cast<Kind>(index_of_t<T, Attribute, Operation *, Type,
+                                        TypeRange, Value, ValueRange>::value);
+  }
+
+  /// The internal implementation of `cast`, that returns the underlying value
+  /// as the given type `T`.
+  template <typename T>
+  std::enable_if_t<llvm::is_one_of<T, Attribute, Type, Value>::value, T>
+  castImpl() const {
+    return T::getFromOpaquePointer(value);
+  }
+  template <typename T>
+  std::enable_if_t<llvm::is_one_of<T, TypeRange, ValueRange>::value, T>
+  castImpl() const {
+    return *reinterpret_cast<T *>(const_cast<void *>(value));
+  }
+  template <typename T>
+  std::enable_if_t<std::is_pointer<T>::value, T> castImpl() const {
+    return reinterpret_cast<T>(const_cast<void *>(value));
+  }
+
+  /// The internal opaque representation of a PDLValue.
+  const void *value{nullptr};
+  /// The kind of the opaque value.
+  Kind kind{Kind::Attribute};
+};
+
+inline raw_ostream &operator<<(raw_ostream &os, PDLValue value) {
+  value.print(os);
+  return os;
+}
+
+inline raw_ostream &operator<<(raw_ostream &os, PDLValue::Kind kind) {
+  PDLValue::print(os, kind);
+  return os;
+}
+
+//===----------------------------------------------------------------------===//
+// PDLResultList
+
+/// The class represents a list of PDL results, returned by a native rewrite
+/// method. It provides the mechanism with which to pass PDLValues back to the
+/// PDL bytecode.
+class PDLResultList {
+public:
+  /// Push a new Attribute value onto the result list.
+  void push_back(Attribute value) { results.push_back(value); }
+
+  /// Push a new Operation onto the result list.
+  void push_back(Operation *value) { results.push_back(value); }
+
+  /// Push a new Type onto the result list.
+  void push_back(Type value) { results.push_back(value); }
+
+  /// Push a new TypeRange onto the result list.
+  void push_back(TypeRange value) {
+    // The lifetime of a TypeRange can't be guaranteed, so we'll need to
+    // allocate a storage for it.
+    llvm::OwningArrayRef<Type> storage(value.size());
+    llvm::copy(value, storage.begin());
+    allocatedTypeRanges.emplace_back(std::move(storage));
+    typeRanges.push_back(allocatedTypeRanges.back());
+    results.push_back(&typeRanges.back());
+  }
+  void push_back(ValueTypeRange<OperandRange> value) {
+    typeRanges.push_back(value);
+    results.push_back(&typeRanges.back());
+  }
+  void push_back(ValueTypeRange<ResultRange> value) {
+    typeRanges.push_back(value);
+    results.push_back(&typeRanges.back());
+  }
+
+  /// Push a new Value onto the result list.
+  void push_back(Value value) { results.push_back(value); }
+
+  /// Push a new ValueRange onto the result list.
+  void push_back(ValueRange value) {
+    // The lifetime of a ValueRange can't be guaranteed, so we'll need to
+    // allocate a storage for it.
+    llvm::OwningArrayRef<Value> storage(value.size());
+    llvm::copy(value, storage.begin());
+    allocatedValueRanges.emplace_back(std::move(storage));
+    valueRanges.push_back(allocatedValueRanges.back());
+    results.push_back(&valueRanges.back());
+  }
+  void push_back(OperandRange value) {
+    valueRanges.push_back(value);
+    results.push_back(&valueRanges.back());
+  }
+  void push_back(ResultRange value) {
+    valueRanges.push_back(value);
+    results.push_back(&valueRanges.back());
+  }
+
+protected:
+  /// Create a new result list with the expected number of results.
+  PDLResultList(unsigned maxNumResults) {
+    // For now just reserve enough space for all of the results. We could do
+    // separate counts per range type, but it isn't really worth it unless there
+    // are a "large" number of results.
+    typeRanges.reserve(maxNumResults);
+    valueRanges.reserve(maxNumResults);
+  }
+
+  /// The PDL results held by this list.
+  SmallVector<PDLValue> results;
+  /// Memory used to store ranges held by the list.
+  SmallVector<TypeRange> typeRanges;
+  SmallVector<ValueRange> valueRanges;
+  /// Memory allocated to store ranges in the result list whose lifetime was
+  /// generated in the native function.
+  SmallVector<llvm::OwningArrayRef<Type>> allocatedTypeRanges;
+  SmallVector<llvm::OwningArrayRef<Value>> allocatedValueRanges;
+};
+
+//===----------------------------------------------------------------------===//
+// PDLPatternConfig
+
+/// An individual configuration for a pattern, which can be accessed by native
+/// functions via the PDLPatternConfigSet. This allows for injecting additional
+/// configuration into PDL patterns that is specific to certain compilation
+/// flows.
+class PDLPatternConfig {
+public:
+  virtual ~PDLPatternConfig() = default;
+
+  /// Hooks that are invoked at the beginning and end of a rewrite of a matched
+  /// pattern. These can be used to setup any specific state necessary for the
+  /// rewrite.
+  virtual void notifyRewriteBegin(PatternRewriter &rewriter) {}
+  virtual void notifyRewriteEnd(PatternRewriter &rewriter) {}
+
+  /// Return the TypeID that represents this configuration.
+  TypeID getTypeID() const { return id; }
+
+protected:
+  PDLPatternConfig(TypeID id) : id(id) {}
+
+private:
+  TypeID id;
+};
+
+/// This class provides a base class for users implementing a type of pattern
+/// configuration.
+template <typename T>
+class PDLPatternConfigBase : public PDLPatternConfig {
+public:
+  /// Support LLVM style casting.
+  static bool classof(const PDLPatternConfig *config) {
+    return config->getTypeID() == getConfigID();
+  }
+
+  /// Return the type id used for this configuration.
+  static TypeID getConfigID() { return TypeID::get<T>(); }
+
+protected:
+  PDLPatternConfigBase() : PDLPatternConfig(getConfigID()) {}
+};
+
+/// This class contains a set of configurations for a specific pattern.
+/// Configurations are uniqued by TypeID, meaning that only one configuration of
+/// each type is allowed.
+class PDLPatternConfigSet {
+public:
+  PDLPatternConfigSet() = default;
+
+  /// Construct a set with the given configurations.
+  template <typename... ConfigsT>
+  PDLPatternConfigSet(ConfigsT &&...configs) {
+    (addConfig(std::forward<ConfigsT>(configs)), ...);
+  }
+
+  /// Get the configuration defined by the given type. Asserts that the
+  /// configuration of the provided type exists.
+  template <typename T>
+  const T &get() const {
+    const T *config = tryGet<T>();
+    assert(config && "configuration not found");
+    return *config;
+  }
+
+  /// Get the configuration defined by the given type, returns nullptr if the
+  /// configuration does not exist.
+  template <typename T>
+  const T *tryGet() const {
+    for (const auto &configIt : configs)
+      if (const T *config = dyn_cast<T>(configIt.get()))
+        return config;
+    return nullptr;
+  }
+
+  /// Notify the configurations within this set at the beginning or end of a
+  /// rewrite of a matched pattern.
+  void notifyRewriteBegin(PatternRewriter &rewriter) {
+    for (const auto &config : configs)
+      config->notifyRewriteBegin(rewriter);
+  }
+  void notifyRewriteEnd(PatternRewriter &rewriter) {
+    for (const auto &config : configs)
+      config->notifyRewriteEnd(rewriter);
+  }
+
+protected:
+  /// Add a configuration to the set.
+  template <typename T>
+  void addConfig(T &&config) {
+    assert(!tryGet<std::decay_t<T>>() && "configuration already exists");
+    configs.emplace_back(
+        std::make_unique<std::decay_t<T>>(std::forward<T>(config)));
+  }
+
+  /// The set of configurations for this pattern. This uses a vector instead of
+  /// a map with the expectation that the number of configurations per set is
+  /// small (<= 1).
+  SmallVector<std::unique_ptr<PDLPatternConfig>> configs;
+};
+
+//===----------------------------------------------------------------------===//
+// PDLPatternModule
+
+/// A generic PDL pattern constraint function. This function applies a
+/// constraint to a given set of opaque PDLValue entities. Returns success if
+/// the constraint successfully held, failure otherwise.
+using PDLConstraintFunction =
+    std::function<LogicalResult(PatternRewriter &, ArrayRef<PDLValue>)>;
+/// A native PDL rewrite function. This function performs a rewrite on the
+/// given set of values. Any results from this rewrite that should be passed
+/// back to PDL should be added to the provided result list. This method is only
+/// invoked when the corresponding match was successful. Returns failure if an
+/// invariant of the rewrite was broken (certain rewriters may recover from
+/// partial pattern application).
+using PDLRewriteFunction = std::function<LogicalResult(
+    PatternRewriter &, PDLResultList &, ArrayRef<PDLValue>)>;
+
+namespace detail {
+namespace pdl_function_builder {
+/// A utility variable that always resolves to false. This is useful for static
+/// asserts that are always false, but only should fire in certain templated
+/// constructs. For example, if a templated function should never be called, the
+/// function could be defined as:
+///
+/// template <typename T>
+/// void foo() {
+///  static_assert(always_false<T>, "This function should never be called");
+/// }
+///
+template <class... T>
+constexpr bool always_false = false;
+
+//===----------------------------------------------------------------------===//
+// PDL Function Builder: Type Processing
+//===----------------------------------------------------------------------===//
+
+/// This struct provides a convenient way to determine how to process a given
+/// type as either a PDL parameter, or a result value. This allows for
+/// supporting complex types in constraint and rewrite functions, without
+/// requiring the user to hand-write the necessary glue code themselves.
+/// Specializations of this class should implement the following methods to
+/// enable support as a PDL argument or result type:
+///
+///   static LogicalResult verifyAsArg(
+///     function_ref<LogicalResult(const Twine &)> errorFn, PDLValue pdlValue,
+///     size_t argIdx);
+///
+///     * This method verifies that the given PDLValue is valid for use as a
+///       value of `T`.
+///
+///   static T processAsArg(PDLValue pdlValue);
+///
+///     *  This method processes the given PDLValue as a value of `T`.
+///
+///   static void processAsResult(PatternRewriter &, PDLResultList &results,
+///                               const T &value);
+///
+///     *  This method processes the given value of `T` as the result of a
+///        function invocation. The method should package the value into an
+///        appropriate form and append it to the given result list.
+///
+/// If the type `T` is based on a higher order value, consider using
+/// `ProcessPDLValueBasedOn` as a base class of the specialization to simplify
+/// the implementation.
+///
+template <typename T, typename Enable = void>
+struct ProcessPDLValue;
+
+/// This struct provides a simplified model for processing types that are based
+/// on another type, e.g. APInt is based on the handling for IntegerAttr. This
+/// allows for building the necessary processing functions on top of the base
+/// value instead of a PDLValue. Derived users should implement the following
+/// (which subsume the ProcessPDLValue variants):
+///
+///   static LogicalResult verifyAsArg(
+///     function_ref<LogicalResult(const Twine &)> errorFn,
+///     const BaseT &baseValue, size_t argIdx);
+///
+///     * This method verifies that the given PDLValue is valid for use as a
+///       value of `T`.
+///
+///   static T processAsArg(BaseT baseValue);
+///
+///     *  This method processes the given base value as a value of `T`.
+///
+template <typename T, typename BaseT>
+struct ProcessPDLValueBasedOn {
+  static LogicalResult
+  verifyAsArg(function_ref<LogicalResult(const Twine &)> errorFn,
+              PDLValue pdlValue, size_t argIdx) {
+    // Verify the base class before continuing.
+    if (failed(ProcessPDLValue<BaseT>::verifyAsArg(errorFn, pdlValue, argIdx)))
+      return failure();
+    return ProcessPDLValue<T>::verifyAsArg(
+        errorFn, ProcessPDLValue<BaseT>::processAsArg(pdlValue), argIdx);
+  }
+  static T processAsArg(PDLValue pdlValue) {
+    return ProcessPDLValue<T>::processAsArg(
+        ProcessPDLValue<BaseT>::processAsArg(pdlValue));
+  }
+
+  /// Explicitly add the expected parent API to ensure the parent class
+  /// implements the necessary API (and doesn't implicitly inherit it from
+  /// somewhere else).
+  static LogicalResult
+  verifyAsArg(function_ref<LogicalResult(const Twine &)> errorFn, BaseT value,
+              size_t argIdx) {
+    return success();
+  }
+  static T processAsArg(BaseT baseValue);
+};
+
+/// ...
[truncated]
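In the moved header, `PDLValue::getKindOf` maps a C++ type to its `Kind` enumerator by looking up the type's position in a fixed type list with `index_of_t`. A standalone sketch of that technique, using three stand-in types instead of MLIR's `Attribute`, `Operation *`, and `Type` (the `sketch` namespace and the stand-in structs are assumptions made for illustration):

```cpp
#include <cassert>
#include <cstddef>
#include <type_traits>

namespace sketch {

// Stand-in payload types; the real header uses MLIR's own classes.
struct Attr {};
struct Op {};
struct Ty {};

// Enumerator order must match the order of the type list in getKindOf.
enum class Kind { Attribute, Operation, Type };

// index_of_t<T, Ts...>::value is the position of T within Ts...
template <typename...>
struct index_of_t;
// Match found at the head of the list: index 0.
template <typename T, typename... R>
struct index_of_t<T, T, R...> : std::integral_constant<std::size_t, 0> {};
// Head differs: skip it and add one to the index found in the tail.
template <typename T, typename F, typename... R>
struct index_of_t<T, F, R...>
    : std::integral_constant<std::size_t, 1 + index_of_t<T, R...>::value> {};

// Return the kind used for the given T.
template <typename T>
constexpr Kind getKindOf() {
  return static_cast<Kind>(index_of_t<T, Attr, Op *, Ty>::value);
}

static_assert(getKindOf<Op *>() == Kind::Operation,
              "list position and enumerator must agree");

} // namespace sketch
```

Asking for a type that is not in the list fails to compile, because the recursion ends at the undefined primary template.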

@llvmbot
Member

llvmbot commented Dec 10, 2023

@llvm/pr-subscribers-mlir-llvm

    -DLLVM_ENABLE_BACKTRACES=OFF \
+   -DMLIR_ENABLE_PDL=OFF \
    -DCMAKE_EXE_LINKER_FLAGS_RELWITHDEBINFO=-Wl,-icf=all
 ```
 
diff --git a/mlir/include/mlir/Config/mlir-config.h.cmake b/mlir/include/mlir/Config/mlir-config.h.cmake
index efa77b2e5ce5db..71e394eeec6133 100644
--- a/mlir/include/mlir/Config/mlir-config.h.cmake
+++ b/mlir/include/mlir/Config/mlir-config.h.cmake
@@ -26,4 +26,7 @@
    numeric seed that is passed to the random number generator. */
 #cmakedefine MLIR_GREEDY_REWRITE_RANDOMIZER_SEED ${MLIR_GREEDY_REWRITE_RANDOMIZER_SEED}
 
+/* If set, enables PDL usage. */
+#cmakedefine01 MLIR_ENABLE_PDL
+
 #endif
diff --git a/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h b/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
index 74f9c977b70286..e228229302cff4 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
@@ -15,6 +15,7 @@
 #define MLIR_CONVERSION_LLVMCOMMON_TYPECONVERTER_H
 
 #include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
+#include "mlir/IR/BuiltinTypes.h"
 #include "mlir/Transforms/DialectConversion.h"
 
 namespace mlir {
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
index 59d585a77b1e29..9dfbd0ff31ecea 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
@@ -29,6 +29,7 @@
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "mlir/Interfaces/VectorInterfaces.h"
 #include "mlir/Interfaces/ViewLikeInterface.h"
+#include "llvm/ADT/SetVector.h"
 #include "llvm/ADT/StringExtras.h"
 
 // Pull in all enum type definitions and utility function declarations.
diff --git a/mlir/include/mlir/IR/PDLPatternMatch.h.inc b/mlir/include/mlir/IR/PDLPatternMatch.h.inc
new file mode 100644
index 00000000000000..e4d63248bf1908
--- /dev/null
+++ b/mlir/include/mlir/IR/PDLPatternMatch.h.inc
@@ -0,0 +1,992 @@
+//===- PDLPatternMatch.h - PDLPatternMatcher classes -------==---*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_IR_PDLPATTERNMATCH_H
+#define MLIR_IR_PDLPATTERNMATCH_H
+
+#if MLIR_ENABLE_PDL
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinOps.h"
+
+namespace mlir {
+//===----------------------------------------------------------------------===//
+// PDL Patterns
+//===----------------------------------------------------------------------===//
+
+//===----------------------------------------------------------------------===//
+// PDLValue
+
+/// Storage type of byte-code interpreter values. These are passed to constraint
+/// functions as arguments.
+class PDLValue {
+public:
+  /// The underlying kind of a PDL value.
+  enum class Kind { Attribute, Operation, Type, TypeRange, Value, ValueRange };
+
+  /// Construct a new PDL value.
+  PDLValue(const PDLValue &other) = default;
+  PDLValue(std::nullptr_t = nullptr) {}
+  PDLValue(Attribute value)
+      : value(value.getAsOpaquePointer()), kind(Kind::Attribute) {}
+  PDLValue(Operation *value) : value(value), kind(Kind::Operation) {}
+  PDLValue(Type value) : value(value.getAsOpaquePointer()), kind(Kind::Type) {}
+  PDLValue(TypeRange *value) : value(value), kind(Kind::TypeRange) {}
+  PDLValue(Value value)
+      : value(value.getAsOpaquePointer()), kind(Kind::Value) {}
+  PDLValue(ValueRange *value) : value(value), kind(Kind::ValueRange) {}
+
+  /// Returns true if the type of the held value is `T`.
+  template <typename T>
+  bool isa() const {
+    assert(value && "isa<> used on a null value");
+    return kind == getKindOf<T>();
+  }
+
+  /// Attempt to dynamically cast this value to type `T`, returns null if this
+  /// value is not an instance of `T`.
+  template <typename T,
+            typename ResultT = std::conditional_t<
+                std::is_convertible<T, bool>::value, T, std::optional<T>>>
+  ResultT dyn_cast() const {
+    return isa<T>() ? castImpl<T>() : ResultT();
+  }
+
+  /// Cast this value to type `T`, asserts if this value is not an instance of
+  /// `T`.
+  template <typename T>
+  T cast() const {
+    assert(isa<T>() && "expected value to be of type `T`");
+    return castImpl<T>();
+  }
+
+  /// Get an opaque pointer to the value.
+  const void *getAsOpaquePointer() const { return value; }
+
+  /// Return if this value is null or not.
+  explicit operator bool() const { return value; }
+
+  /// Return the kind of this value.
+  Kind getKind() const { return kind; }
+
+  /// Print this value to the provided output stream.
+  void print(raw_ostream &os) const;
+
+  /// Print the specified value kind to an output stream.
+  static void print(raw_ostream &os, Kind kind);
+
+private:
+  /// Find the index of a given type in a range of other types.
+  template <typename...>
+  struct index_of_t;
+  template <typename T, typename... R>
+  struct index_of_t<T, T, R...> : std::integral_constant<size_t, 0> {};
+  template <typename T, typename F, typename... R>
+  struct index_of_t<T, F, R...>
+      : std::integral_constant<size_t, 1 + index_of_t<T, R...>::value> {};
+
+  /// Return the kind used for the given T.
+  template <typename T>
+  static Kind getKindOf() {
+    return static_cast<Kind>(index_of_t<T, Attribute, Operation *, Type,
+                                        TypeRange, Value, ValueRange>::value);
+  }
+
+  /// The internal implementation of `cast`, that returns the underlying value
+  /// as the given type `T`.
+  template <typename T>
+  std::enable_if_t<llvm::is_one_of<T, Attribute, Type, Value>::value, T>
+  castImpl() const {
+    return T::getFromOpaquePointer(value);
+  }
+  template <typename T>
+  std::enable_if_t<llvm::is_one_of<T, TypeRange, ValueRange>::value, T>
+  castImpl() const {
+    return *reinterpret_cast<T *>(const_cast<void *>(value));
+  }
+  template <typename T>
+  std::enable_if_t<std::is_pointer<T>::value, T> castImpl() const {
+    return reinterpret_cast<T>(const_cast<void *>(value));
+  }
+
+  /// The internal opaque representation of a PDLValue.
+  const void *value{nullptr};
+  /// The kind of the opaque value.
+  Kind kind{Kind::Attribute};
+};
+
+inline raw_ostream &operator<<(raw_ostream &os, PDLValue value) {
+  value.print(os);
+  return os;
+}
+
+inline raw_ostream &operator<<(raw_ostream &os, PDLValue::Kind kind) {
+  PDLValue::print(os, kind);
+  return os;
+}
+
+//===----------------------------------------------------------------------===//
+// PDLResultList
+
+/// The class represents a list of PDL results, returned by a native rewrite
+/// method. It provides the mechanism with which to pass PDLValues back to the
+/// PDL bytecode.
+class PDLResultList {
+public:
+  /// Push a new Attribute value onto the result list.
+  void push_back(Attribute value) { results.push_back(value); }
+
+  /// Push a new Operation onto the result list.
+  void push_back(Operation *value) { results.push_back(value); }
+
+  /// Push a new Type onto the result list.
+  void push_back(Type value) { results.push_back(value); }
+
+  /// Push a new TypeRange onto the result list.
+  void push_back(TypeRange value) {
+    // The lifetime of a TypeRange can't be guaranteed, so we'll need to
+    // allocate a storage for it.
+    llvm::OwningArrayRef<Type> storage(value.size());
+    llvm::copy(value, storage.begin());
+    allocatedTypeRanges.emplace_back(std::move(storage));
+    typeRanges.push_back(allocatedTypeRanges.back());
+    results.push_back(&typeRanges.back());
+  }
+  void push_back(ValueTypeRange<OperandRange> value) {
+    typeRanges.push_back(value);
+    results.push_back(&typeRanges.back());
+  }
+  void push_back(ValueTypeRange<ResultRange> value) {
+    typeRanges.push_back(value);
+    results.push_back(&typeRanges.back());
+  }
+
+  /// Push a new Value onto the result list.
+  void push_back(Value value) { results.push_back(value); }
+
+  /// Push a new ValueRange onto the result list.
+  void push_back(ValueRange value) {
+    // The lifetime of a ValueRange can't be guaranteed, so we'll need to
+    // allocate a storage for it.
+    llvm::OwningArrayRef<Value> storage(value.size());
+    llvm::copy(value, storage.begin());
+    allocatedValueRanges.emplace_back(std::move(storage));
+    valueRanges.push_back(allocatedValueRanges.back());
+    results.push_back(&valueRanges.back());
+  }
+  void push_back(OperandRange value) {
+    valueRanges.push_back(value);
+    results.push_back(&valueRanges.back());
+  }
+  void push_back(ResultRange value) {
+    valueRanges.push_back(value);
+    results.push_back(&valueRanges.back());
+  }
+
+protected:
+  /// Create a new result list with the expected number of results.
+  PDLResultList(unsigned maxNumResults) {
+    // For now just reserve enough space for all of the results. We could do
+    // separate counts per range type, but it isn't really worth it unless there
+    // are a "large" number of results.
+    typeRanges.reserve(maxNumResults);
+    valueRanges.reserve(maxNumResults);
+  }
+
+  /// The PDL results held by this list.
+  SmallVector<PDLValue> results;
+  /// Memory used to store ranges held by the list.
+  SmallVector<TypeRange> typeRanges;
+  SmallVector<ValueRange> valueRanges;
+  /// Memory allocated to store ranges in the result list whose lifetime was
+  /// generated in the native function.
+  SmallVector<llvm::OwningArrayRef<Type>> allocatedTypeRanges;
+  SmallVector<llvm::OwningArrayRef<Value>> allocatedValueRanges;
+};
+
+//===----------------------------------------------------------------------===//
+// PDLPatternConfig
+
+/// An individual configuration for a pattern, which can be accessed by native
+/// functions via the PDLPatternConfigSet. This allows for injecting additional
+/// configuration into PDL patterns that is specific to certain compilation
+/// flows.
+class PDLPatternConfig {
+public:
+  virtual ~PDLPatternConfig() = default;
+
+  /// Hooks that are invoked at the beginning and end of a rewrite of a matched
+  /// pattern. These can be used to setup any specific state necessary for the
+  /// rewrite.
+  virtual void notifyRewriteBegin(PatternRewriter &rewriter) {}
+  virtual void notifyRewriteEnd(PatternRewriter &rewriter) {}
+
+  /// Return the TypeID that represents this configuration.
+  TypeID getTypeID() const { return id; }
+
+protected:
+  PDLPatternConfig(TypeID id) : id(id) {}
+
+private:
+  TypeID id;
+};
+
+/// This class provides a base class for users implementing a type of pattern
+/// configuration.
+template <typename T>
+class PDLPatternConfigBase : public PDLPatternConfig {
+public:
+  /// Support LLVM style casting.
+  static bool classof(const PDLPatternConfig *config) {
+    return config->getTypeID() == getConfigID();
+  }
+
+  /// Return the type id used for this configuration.
+  static TypeID getConfigID() { return TypeID::get<T>(); }
+
+protected:
+  PDLPatternConfigBase() : PDLPatternConfig(getConfigID()) {}
+};
+
+/// This class contains a set of configurations for a specific pattern.
+/// Configurations are uniqued by TypeID, meaning that only one configuration of
+/// each type is allowed.
+class PDLPatternConfigSet {
+public:
+  PDLPatternConfigSet() = default;
+
+  /// Construct a set with the given configurations.
+  template <typename... ConfigsT>
+  PDLPatternConfigSet(ConfigsT &&...configs) {
+    (addConfig(std::forward<ConfigsT>(configs)), ...);
+  }
+
+  /// Get the configuration defined by the given type. Asserts that the
+  /// configuration of the provided type exists.
+  template <typename T>
+  const T &get() const {
+    const T *config = tryGet<T>();
+    assert(config && "configuration not found");
+    return *config;
+  }
+
+  /// Get the configuration defined by the given type, returns nullptr if the
+  /// configuration does not exist.
+  template <typename T>
+  const T *tryGet() const {
+    for (const auto &configIt : configs)
+      if (const T *config = dyn_cast<T>(configIt.get()))
+        return config;
+    return nullptr;
+  }
+
+  /// Notify the configurations within this set at the beginning or end of a
+  /// rewrite of a matched pattern.
+  void notifyRewriteBegin(PatternRewriter &rewriter) {
+    for (const auto &config : configs)
+      config->notifyRewriteBegin(rewriter);
+  }
+  void notifyRewriteEnd(PatternRewriter &rewriter) {
+    for (const auto &config : configs)
+      config->notifyRewriteEnd(rewriter);
+  }
+
+protected:
+  /// Add a configuration to the set.
+  template <typename T>
+  void addConfig(T &&config) {
+    assert(!tryGet<std::decay_t<T>>() && "configuration already exists");
+    configs.emplace_back(
+        std::make_unique<std::decay_t<T>>(std::forward<T>(config)));
+  }
+
+  /// The set of configurations for this pattern. This uses a vector instead of
+  /// a map with the expectation that the number of configurations per set is
+  /// small (<= 1).
+  SmallVector<std::unique_ptr<PDLPatternConfig>> configs;
+};
+
+//===----------------------------------------------------------------------===//
+// PDLPatternModule
+
+/// A generic PDL pattern constraint function. This function applies a
+/// constraint to a given set of opaque PDLValue entities. Returns success if
+/// the constraint successfully held, failure otherwise.
+using PDLConstraintFunction =
+    std::function<LogicalResult(PatternRewriter &, ArrayRef<PDLValue>)>;
+/// A native PDL rewrite function. This function performs a rewrite on the
+/// given set of values. Any results from this rewrite that should be passed
+/// back to PDL should be added to the provided result list. This method is only
+/// invoked when the corresponding match was successful. Returns failure if an
+/// invariant of the rewrite was broken (certain rewriters may recover from
+/// partial pattern application).
+using PDLRewriteFunction = std::function<LogicalResult(
+    PatternRewriter &, PDLResultList &, ArrayRef<PDLValue>)>;
+
+namespace detail {
+namespace pdl_function_builder {
+/// A utility variable that always resolves to false. This is useful for static
+/// asserts that are always false, but only should fire in certain templated
+/// constructs. For example, if a templated function should never be called, the
+/// function could be defined as:
+///
+/// template <typename T>
+/// void foo() {
+///  static_assert(always_false<T>, "This function should never be called");
+/// }
+///
+template <class... T>
+constexpr bool always_false = false;
+
+//===----------------------------------------------------------------------===//
+// PDL Function Builder: Type Processing
+//===----------------------------------------------------------------------===//
+
+/// This struct provides a convenient way to determine how to process a given
+/// type as either a PDL parameter, or a result value. This allows for
+/// supporting complex types in constraint and rewrite functions, without
+/// requiring the user to hand-write the necessary glue code themselves.
+/// Specializations of this class should implement the following methods to
+/// enable support as a PDL argument or result type:
+///
+///   static LogicalResult verifyAsArg(
+///     function_ref<LogicalResult(const Twine &)> errorFn, PDLValue pdlValue,
+///     size_t argIdx);
+///
+///     * This method verifies that the given PDLValue is valid for use as a
+///       value of `T`.
+///
+///   static T processAsArg(PDLValue pdlValue);
+///
+///     *  This method processes the given PDLValue as a value of `T`.
+///
+///   static void processAsResult(PatternRewriter &, PDLResultList &results,
+///                               const T &value);
+///
+///     *  This method processes the given value of `T` as the result of a
+///        function invocation. The method should package the value into an
+///        appropriate form and append it to the given result list.
+///
+/// If the type `T` is based on a higher order value, consider using
+/// `ProcessPDLValueBasedOn` as a base class of the specialization to simplify
+/// the implementation.
+///
+template <typename T, typename Enable = void>
+struct ProcessPDLValue;
+
+/// This struct provides a simplified model for processing types that are based
+/// on another type, e.g. APInt is based on the handling for IntegerAttr. This
+/// allows for building the necessary processing functions on top of the base
+/// value instead of a PDLValue. Derived users should implement the following
+/// (which subsume the ProcessPDLValue variants):
+///
+///   static LogicalResult verifyAsArg(
+///     function_ref<LogicalResult(const Twine &)> errorFn,
+///     const BaseT &baseValue, size_t argIdx);
+///
+///     * This method verifies that the given PDLValue is valid for use as a
+///       value of `T`.
+///
+///   static T processAsArg(BaseT baseValue);
+///
+///     *  This method processes the given base value as a value of `T`.
+///
+template <typename T, typename BaseT>
+struct ProcessPDLValueBasedOn {
+  static LogicalResult
+  verifyAsArg(function_ref<LogicalResult(const Twine &)> errorFn,
+              PDLValue pdlValue, size_t argIdx) {
+    // Verify the base class before continuing.
+    if (failed(ProcessPDLValue<BaseT>::verifyAsArg(errorFn, pdlValue, argIdx)))
+      return failure();
+    return ProcessPDLValue<T>::verifyAsArg(
+        errorFn, ProcessPDLValue<BaseT>::processAsArg(pdlValue), argIdx);
+  }
+  static T processAsArg(PDLValue pdlValue) {
+    return ProcessPDLValue<T>::processAsArg(
+        ProcessPDLValue<BaseT>::processAsArg(pdlValue));
+  }
+
+  /// Explicitly add the expected parent API to ensure the parent class
+  /// implements the necessary API (and doesn't implicitly inherit it from
+  /// somewhere else).
+  static LogicalResult
+  verifyAsArg(function_ref<LogicalResult(const Twine &)> errorFn, BaseT value,
+              size_t argIdx) {
+    return success();
+  }
+  static T processAsArg(BaseT baseValue);
+};
+
+/// ...
[truncated]
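
The `PDLValue` class in the header above pairs an opaque pointer with a `Kind` tag, and derives the tag from the position of a type in a fixed type list (`index_of_t`). The following is a minimal standalone sketch of that idea, not the MLIR code: `TaggedValue` and its `int`/`double`/`char` payloads are hypothetical stand-ins chosen so the example compiles without MLIR.

```cpp
#include <cassert>
#include <cstddef>
#include <type_traits>

// Hypothetical stand-in for PDLValue: an opaque pointer plus a kind tag.
// The payload types (int, double, char) are illustrative only.
class TaggedValue {
public:
  enum class Kind { Int, Double, Char };

  TaggedValue() = default;
  TaggedValue(int *v) : value(v), kind(Kind::Int) {}
  TaggedValue(double *v) : value(v), kind(Kind::Double) {}
  TaggedValue(char *v) : value(v), kind(Kind::Char) {}

  /// Returns true if the held pointer was stored as a `T *`.
  template <typename T>
  bool isa() const {
    assert(value && "isa<> used on a null value");
    return kind == getKindOf<T>();
  }

  /// Returns the held pointer as `T *`, or nullptr on a kind mismatch.
  template <typename T>
  T *dyn_cast() const {
    return isa<T>() ? static_cast<T *>(const_cast<void *>(value)) : nullptr;
  }

  /// True if a non-null pointer is held.
  explicit operator bool() const { return value != nullptr; }

private:
  /// Compile-time index of `T` within a list of types.
  template <typename...>
  struct index_of_t;
  template <typename T, typename... R>
  struct index_of_t<T, T, R...> : std::integral_constant<std::size_t, 0> {};
  template <typename T, typename F, typename... R>
  struct index_of_t<T, F, R...>
      : std::integral_constant<std::size_t, 1 + index_of_t<T, R...>::value> {};

  /// The type list must be kept in the same order as the Kind enumerators.
  template <typename T>
  static Kind getKindOf() {
    return static_cast<Kind>(index_of_t<T, int, double, char>::value);
  }

  const void *value = nullptr;
  Kind kind = Kind::Int;
};
```

The design trades type safety for a uniform, fixed-size representation: the tag check in `isa` is what makes the later `static_cast` safe, which is why the type list and the enum must stay in the same order.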

Collaborator

@joker-eph joker-eph left a comment

Looks pretty good to me overall!

MLIR is a general infrastructure and as far as possible no dialect is intended to be special or privileged (we have Builtin which is both and should be addressed ...). Users should be able to use MLIR and the parts they need to customize for their solution (e.g., one doesn't have to include any dialect except the ones one uses).

Can you rework the description please? I feel it is misleading.
PDL/PDLL are dialects for the pattern rewrite infra, and what you're doing is making it possible for the "pattern rewrite infrastructure" to statically disable support for PDL/PDLL. This does not make any dialect more "privileged" or "special". The fact that these are "dialects" is anecdotal: we could just as well have written a different DSL with a custom bytecode interpreter without using an MLIR dialect. The misleading part is that the description presents them in the same way as if they were dialects like the others, intended to blend into the compiler ecosystem.

@@ -0,0 +1,992 @@
//===- PDLPatternMatch.h - PDLPatternMatcher classes -------==---*- C++ -*-===//
Collaborator

It's not clear to me why this has a .h.inc extension; it seems like a regular .h to me.

Isn't it standalone? If not, can you point out why?

Collaborator

ping?

Member Author

It's not, as it depends on data structures defined above it. It also doesn't make sense to use it in isolation, since the functionality is not self-contained.


#if MLIR_ENABLE_PDL
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
Collaborator

@joker-eph joker-eph Dec 10, 2023

There is no include from there to the PDL dialect: where is the coupling?

Member Author

These structures only make sense with PDL. I want them all to be clearly stubbed out so that they return empty arrays, etc. We could probably leave them in and hope DCE does its thing during compilation, but I wanted to be more explicit.
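
The explicit stubbing described in this reply can be sketched as follows. This is only an illustration of the general technique, not the code in the patch: `DEMO_ENABLE_PDL`, `PatternStore`, and `countPatterns` are hypothetical names, and the macro stands in for the 0/1 value that a configured header would provide.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for a 0/1 configuration macro; defaults to "disabled"
// here so that the stub branch below is the one that gets compiled.
#ifndef DEMO_ENABLE_PDL
#define DEMO_ENABLE_PDL 0
#endif

#if DEMO_ENABLE_PDL
// Full implementation: stores every pattern name that is added.
class PatternStore {
public:
  void add(std::string name) { names.push_back(std::move(name)); }
  const std::vector<std::string> &patterns() const { return names; }

private:
  std::vector<std::string> names;
};
#else
// Stub: same interface, but nothing is stored and the list is always empty.
class PatternStore {
public:
  void add(std::string) {}
  const std::vector<std::string> &patterns() const {
    static const std::vector<std::string> empty;
    return empty;
  }
};
#endif

// Client code compiles unchanged against either variant.
inline std::size_t countPatterns(const PatternStore &store) {
  return store.patterns().size();
}
```

Compared with relying on dead-code elimination, the stub makes the disabled configuration visible in the source and keeps the disabled build from depending on the feature's headers at all, at the cost of maintaining two definitions with matching interfaces.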

github-actions bot commented Dec 19, 2023

✅ With the latest revision this PR passed the C/C++ code formatter.

@joker-eph
Collaborator

Ping :)

@jpienaar
Member Author

Done, simplified it to just that.

Make it so that PDL in pattern rewrites can be optionally disabled.

PDL is still enabled by default, and is not optional in the Bazel build. So
this should be a NOP for most folks, while allowing others to disable it.

This only works with tests disabled. With tests enabled this still
compiles but tests fail as there is no lit config to disable tests that
depend on PDL rewrites yet.
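
One detail of the diff above is worth spelling out: `mlir-config.h.cmake` uses `#cmakedefine01`, which always defines the macro, to either 0 or 1. Code therefore has to test it with `#if`, as `PDLPatternMatch.h.inc` does, rather than with `#ifdef`. The sketch below illustrates the difference; `DEMO_ENABLE_PDL` and the two functions are hypothetical names, and the `#define` line imitates what the configured header would contain with the option turned off.

```cpp
#include <cassert>

// Imitates the output of `#cmakedefine01 DEMO_ENABLE_PDL` when the option is
// OFF: the macro is still defined, just with the value 0.
#define DEMO_ENABLE_PDL 0

// Tests the macro's value: correct for 0/1 macros.
constexpr bool enabledViaIf() {
#if DEMO_ENABLE_PDL
  return true;
#else
  return false;
#endif
}

// Tests only whether the macro is defined: wrong for 0/1 macros, because the
// macro is defined in both configurations.
constexpr bool enabledViaIfdef() {
#ifdef DEMO_ENABLE_PDL
  return true;
#else
  return false;
#endif
}

static_assert(!enabledViaIf(), "#if sees the value 0");
static_assert(enabledViaIfdef(), "#ifdef only sees that the macro exists");
```
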
@jpienaar jpienaar merged commit 5930725 into llvm:main Jan 3, 2024
makslevental added a commit that referenced this pull request Jan 3, 2024
jpienaar added a commit to jpienaar/llvm-project that referenced this pull request Jan 3, 2024
jpienaar added a commit that referenced this pull request Jan 4, 2024
Labels
mlir:bufferization Bufferization infrastructure mlir:core MLIR Core Infrastructure mlir:llvm mlir:pdl mlir:python MLIR Python bindings mlir:vector mlir:vectorops mlir
3 participants