Skip to content

[mlir][NFC] Move and rename EnumAttrCase, EnumAttr C++ classes #132650

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Mar 27, 2025

Conversation

krzysz00
Copy link
Contributor

This moves the EnumAttrCase and EnumAttr classes from Attribute.h/.cpp to a new EnumInfo.h/cpp and renames them to EnumCase and EnumInfo, respectively.

This doesn't change any of the tablegen files or any user-facing aspects of the enum attribute generation system, just reorganizes code in order to make main PR (#132148) shorter.

This moves the EnumAttrCase and EnumAttr classes from Attribute.h/.cpp
to a new EnumInfo.h/cpp and renames them to EnumCase and EnumInfo,
respectively.

This doesn't change any of the tablegen files or any user-facing
aspects of the enum attribute generation system, just reorganizes code
in order to make main PR (#132148) shorter.
@llvmbot
Copy link
Member

llvmbot commented Mar 24, 2025

@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-llvm

@llvm/pr-subscribers-mlir-core

Author: Krzysztof Drewniak (krzysz00)

Changes

This moves the EnumAttrCase and EnumAttr classes from Attribute.h/.cpp to a new EnumInfo.h/cpp and renames them to EnumCase and EnumInfo, respectively.

This doesn't change any of the tablegen files or any user-facing aspects of the enum attribute generation system, just reorganizes code in order to make main PR (#132148) shorter.


Patch is 74.54 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/132650.diff

15 Files Affected:

  • (modified) mlir/include/mlir/TableGen/Attribute.h (-74)
  • (added) mlir/include/mlir/TableGen/EnumInfo.h (+135)
  • (modified) mlir/include/mlir/TableGen/Pattern.h (+6-5)
  • (modified) mlir/lib/TableGen/Attribute.cpp (-94)
  • (modified) mlir/lib/TableGen/CMakeLists.txt (+1)
  • (added) mlir/lib/TableGen/EnumInfo.cpp (+130)
  • (modified) mlir/lib/TableGen/Pattern.cpp (+5-7)
  • (modified) mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp (+26-21)
  • (modified) mlir/tools/mlir-tblgen/EnumsGen.cpp (+85-81)
  • (modified) mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp (+38-37)
  • (modified) mlir/tools/mlir-tblgen/OpDocGen.cpp (+9-8)
  • (modified) mlir/tools/mlir-tblgen/OpFormatGen.cpp (+18-17)
  • (modified) mlir/tools/mlir-tblgen/RewriterGen.cpp (+4-4)
  • (modified) mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp (+46-45)
  • (modified) mlir/tools/mlir-tblgen/TosaUtilsGen.cpp (+3-2)
diff --git a/mlir/include/mlir/TableGen/Attribute.h b/mlir/include/mlir/TableGen/Attribute.h
index 62720e74849fc..dee81880bacab 100644
--- a/mlir/include/mlir/TableGen/Attribute.h
+++ b/mlir/include/mlir/TableGen/Attribute.h
@@ -16,7 +16,6 @@
 
 #include "mlir/Support/LLVM.h"
 #include "mlir/TableGen/Constraint.h"
-#include "llvm/ADT/StringRef.h"
 
 namespace llvm {
 class DefInit;
@@ -136,79 +135,6 @@ class ConstantAttr {
   const llvm::Record *def;
 };
 
-// Wrapper class providing helper methods for accessing enum attribute cases
-// defined in TableGen. This is used for enum attribute case backed by both
-// StringAttr and IntegerAttr.
-class EnumAttrCase : public Attribute {
-public:
-  explicit EnumAttrCase(const llvm::Record *record);
-  explicit EnumAttrCase(const llvm::DefInit *init);
-
-  // Returns the symbol of this enum attribute case.
-  StringRef getSymbol() const;
-
-  // Returns the textual representation of this enum attribute case.
-  StringRef getStr() const;
-
-  // Returns the value of this enum attribute case.
-  int64_t getValue() const;
-
-  // Returns the TableGen definition this EnumAttrCase was constructed from.
-  const llvm::Record &getDef() const;
-};
-
-// Wrapper class providing helper methods for accessing enum attributes defined
-// in TableGen.This is used for enum attribute case backed by both StringAttr
-// and IntegerAttr.
-class EnumAttr : public Attribute {
-public:
-  explicit EnumAttr(const llvm::Record *record);
-  explicit EnumAttr(const llvm::Record &record);
-  explicit EnumAttr(const llvm::DefInit *init);
-
-  static bool classof(const Attribute *attr);
-
-  // Returns true if this is a bit enum attribute.
-  bool isBitEnum() const;
-
-  // Returns the enum class name.
-  StringRef getEnumClassName() const;
-
-  // Returns the C++ namespaces this enum class should be placed in.
-  StringRef getCppNamespace() const;
-
-  // Returns the underlying type.
-  StringRef getUnderlyingType() const;
-
-  // Returns the name of the utility function that converts a value of the
-  // underlying type to the corresponding symbol.
-  StringRef getUnderlyingToSymbolFnName() const;
-
-  // Returns the name of the utility function that converts a string to the
-  // corresponding symbol.
-  StringRef getStringToSymbolFnName() const;
-
-  // Returns the name of the utility function that converts a symbol to the
-  // corresponding string.
-  StringRef getSymbolToStringFnName() const;
-
-  // Returns the return type of the utility function that converts a symbol to
-  // the corresponding string.
-  StringRef getSymbolToStringFnRetType() const;
-
-  // Returns the name of the utilit function that returns the max enum value
-  // used within the enum class.
-  StringRef getMaxEnumValFnName() const;
-
-  // Returns all allowed cases for this enum attribute.
-  std::vector<EnumAttrCase> getAllCases() const;
-
-  bool genSpecializedAttr() const;
-  const llvm::Record *getBaseAttrClass() const;
-  StringRef getSpecializedAttrClassName() const;
-  bool printBitEnumPrimaryGroups() const;
-};
-
 // Name of infer type op interface.
 extern const char *inferTypeOpInterface;
 
diff --git a/mlir/include/mlir/TableGen/EnumInfo.h b/mlir/include/mlir/TableGen/EnumInfo.h
new file mode 100644
index 0000000000000..196267864f325
--- /dev/null
+++ b/mlir/include/mlir/TableGen/EnumInfo.h
@@ -0,0 +1,135 @@
+//===- EnumInfo.h - EnumInfo wrapper class --------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// EnumInfo wrapper to simplify using a TableGen Record defining an Enum
+// via EnumInfo and its `EnumCase`s.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TABLEGEN_ENUMINFO_H_
+#define MLIR_TABLEGEN_ENUMINFO_H_
+
+#include "mlir/Support/LLVM.h"
+#include "mlir/TableGen/Attribute.h"
+#include "llvm/ADT/StringRef.h"
+
+namespace llvm {
+class DefInit;
+class Record;
+} // namespace llvm
+
+namespace mlir {
+namespace tblgen {
+
+// Wrapper class providing around enum cases defined in TableGen.
+class EnumCase {
+public:
+  explicit EnumCase(const llvm::Record *record);
+  explicit EnumCase(const llvm::DefInit *init);
+
+  // Returns the symbol of this enum attribute case.
+  StringRef getSymbol() const;
+
+  // Returns the textual representation of this enum attribute case.
+  StringRef getStr() const;
+
+  // Returns the value of this enum attribute case.
+  int64_t getValue() const;
+
+  // Returns the TableGen definition this EnumAttrCase was constructed from.
+  const llvm::Record &getDef() const;
+
+protected:
+  // The TableGen definition of this constraint.
+  const llvm::Record *def;
+};
+
+// Wrapper class providing helper methods for accessing enums defined
+// in TableGen using EnumInfo. Some methods are only applicable when
+// the enum is also an attribute, or only when it is a bit enum.
+class EnumInfo {
+public:
+  explicit EnumInfo(const llvm::Record *record);
+  explicit EnumInfo(const llvm::Record &record);
+  explicit EnumInfo(const llvm::DefInit *init);
+
+  // Returns true if the given EnumInfo is a subclass of the named TableGen
+  // class.
+  bool isSubClassOf(StringRef className) const;
+
+  // Returns true if this enum is an EnumAttrInfo, thus making it define an
+  // attribute.
+  bool isEnumAttr() const;
+
+  // Create the `Attribute` wrapper around this EnumInfo if it is defining an
+  // attribute.
+  std::optional<Attribute> asEnumAttr() const;
+
+  // Returns true if this is a bit enum.
+  bool isBitEnum() const;
+
+  // Returns the enum class name.
+  StringRef getEnumClassName() const;
+
+  // Returns the C++ namespaces this enum class should be placed in.
+  StringRef getCppNamespace() const;
+
+  // Returns the summary of the enum.
+  StringRef getSummary() const;
+
+  // Returns the description of the enum.
+  StringRef getDescription() const;
+
+  // Returns the underlying type.
+  StringRef getUnderlyingType() const;
+
+  // Returns the name of the utility function that converts a value of the
+  // underlying type to the corresponding symbol.
+  StringRef getUnderlyingToSymbolFnName() const;
+
+  // Returns the name of the utility function that converts a string to the
+  // corresponding symbol.
+  StringRef getStringToSymbolFnName() const;
+
+  // Returns the name of the utility function that converts a symbol to the
+  // corresponding string.
+  StringRef getSymbolToStringFnName() const;
+
+  // Returns the return type of the utility function that converts a symbol to
+  // the corresponding string.
+  StringRef getSymbolToStringFnRetType() const;
+
+  // Returns the name of the utilit function that returns the max enum value
+  // used within the enum class.
+  StringRef getMaxEnumValFnName() const;
+
+  // Returns all allowed cases for this enum attribute.
+  std::vector<EnumCase> getAllCases() const;
+
+  // Only applicable for enum attributes.
+
+  bool genSpecializedAttr() const;
+  const llvm::Record *getBaseAttrClass() const;
+  StringRef getSpecializedAttrClassName() const;
+
+  // Only applicable for bit enums.
+
+  bool printBitEnumPrimaryGroups() const;
+
+  // Returns the TableGen definition this EnumAttrCase was constructed from.
+  const llvm::Record &getDef() const;
+
+protected:
+  // The TableGen definition of this constraint.
+  const llvm::Record *def;
+};
+
+} // namespace tblgen
+} // namespace mlir
+
+#endif
diff --git a/mlir/include/mlir/TableGen/Pattern.h b/mlir/include/mlir/TableGen/Pattern.h
index 80f38fdeffee0..1c9e128f0a0fb 100644
--- a/mlir/include/mlir/TableGen/Pattern.h
+++ b/mlir/include/mlir/TableGen/Pattern.h
@@ -16,6 +16,7 @@
 
 #include "mlir/Support/LLVM.h"
 #include "mlir/TableGen/Argument.h"
+#include "mlir/TableGen/EnumInfo.h"
 #include "mlir/TableGen/Operator.h"
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/Hashing.h"
@@ -78,8 +79,8 @@ class DagLeaf {
   // Returns true if this DAG leaf is specifying a constant attribute.
   bool isConstantAttr() const;
 
-  // Returns true if this DAG leaf is specifying an enum attribute case.
-  bool isEnumAttrCase() const;
+  // Returns true if this DAG leaf is specifying an enum case.
+  bool isEnumCase() const;
 
   // Returns true if this DAG leaf is specifying a string attribute.
   bool isStringAttr() const;
@@ -90,9 +91,9 @@ class DagLeaf {
   // Returns this DAG leaf as an constant attribute. Asserts if fails.
   ConstantAttr getAsConstantAttr() const;
 
-  // Returns this DAG leaf as an enum attribute case.
-  // Precondition: isEnumAttrCase()
-  EnumAttrCase getAsEnumAttrCase() const;
+  // Returns this DAG leaf as an enum case.
+  // Precondition: isEnumCase()
+  EnumCase getAsEnumCase() const;
 
   // Returns the matching condition template inside this DAG leaf. Assumes the
   // leaf is an operand/attribute matcher and asserts otherwise.
diff --git a/mlir/lib/TableGen/Attribute.cpp b/mlir/lib/TableGen/Attribute.cpp
index f9fc58a40f334..142d194260942 100644
--- a/mlir/lib/TableGen/Attribute.cpp
+++ b/mlir/lib/TableGen/Attribute.cpp
@@ -146,98 +146,4 @@ StringRef ConstantAttr::getConstantValue() const {
   return def->getValueAsString("value");
 }
 
-EnumAttrCase::EnumAttrCase(const Record *record) : Attribute(record) {
-  assert(isSubClassOf("EnumAttrCaseInfo") &&
-         "must be subclass of TableGen 'EnumAttrInfo' class");
-}
-
-EnumAttrCase::EnumAttrCase(const DefInit *init)
-    : EnumAttrCase(init->getDef()) {}
-
-StringRef EnumAttrCase::getSymbol() const {
-  return def->getValueAsString("symbol");
-}
-
-StringRef EnumAttrCase::getStr() const { return def->getValueAsString("str"); }
-
-int64_t EnumAttrCase::getValue() const { return def->getValueAsInt("value"); }
-
-const Record &EnumAttrCase::getDef() const { return *def; }
-
-EnumAttr::EnumAttr(const Record *record) : Attribute(record) {
-  assert(isSubClassOf("EnumAttrInfo") &&
-         "must be subclass of TableGen 'EnumAttr' class");
-}
-
-EnumAttr::EnumAttr(const Record &record) : Attribute(&record) {}
-
-EnumAttr::EnumAttr(const DefInit *init) : EnumAttr(init->getDef()) {}
-
-bool EnumAttr::classof(const Attribute *attr) {
-  return attr->isSubClassOf("EnumAttrInfo");
-}
-
-bool EnumAttr::isBitEnum() const { return isSubClassOf("BitEnumAttr"); }
-
-StringRef EnumAttr::getEnumClassName() const {
-  return def->getValueAsString("className");
-}
-
-StringRef EnumAttr::getCppNamespace() const {
-  return def->getValueAsString("cppNamespace");
-}
-
-StringRef EnumAttr::getUnderlyingType() const {
-  return def->getValueAsString("underlyingType");
-}
-
-StringRef EnumAttr::getUnderlyingToSymbolFnName() const {
-  return def->getValueAsString("underlyingToSymbolFnName");
-}
-
-StringRef EnumAttr::getStringToSymbolFnName() const {
-  return def->getValueAsString("stringToSymbolFnName");
-}
-
-StringRef EnumAttr::getSymbolToStringFnName() const {
-  return def->getValueAsString("symbolToStringFnName");
-}
-
-StringRef EnumAttr::getSymbolToStringFnRetType() const {
-  return def->getValueAsString("symbolToStringFnRetType");
-}
-
-StringRef EnumAttr::getMaxEnumValFnName() const {
-  return def->getValueAsString("maxEnumValFnName");
-}
-
-std::vector<EnumAttrCase> EnumAttr::getAllCases() const {
-  const auto *inits = def->getValueAsListInit("enumerants");
-
-  std::vector<EnumAttrCase> cases;
-  cases.reserve(inits->size());
-
-  for (const Init *init : *inits) {
-    cases.emplace_back(cast<DefInit>(init));
-  }
-
-  return cases;
-}
-
-bool EnumAttr::genSpecializedAttr() const {
-  return def->getValueAsBit("genSpecializedAttr");
-}
-
-const Record *EnumAttr::getBaseAttrClass() const {
-  return def->getValueAsDef("baseAttrClass");
-}
-
-StringRef EnumAttr::getSpecializedAttrClassName() const {
-  return def->getValueAsString("specializedAttrClassName");
-}
-
-bool EnumAttr::printBitEnumPrimaryGroups() const {
-  return def->getValueAsBit("printBitEnumPrimaryGroups");
-}
-
 const char * ::mlir::tblgen::inferTypeOpInterface = "InferTypeOpInterface";
diff --git a/mlir/lib/TableGen/CMakeLists.txt b/mlir/lib/TableGen/CMakeLists.txt
index c4104e644147c..a90c55847718e 100644
--- a/mlir/lib/TableGen/CMakeLists.txt
+++ b/mlir/lib/TableGen/CMakeLists.txt
@@ -20,6 +20,7 @@ llvm_add_library(MLIRTableGen STATIC
   CodeGenHelpers.cpp
   Constraint.cpp
   Dialect.cpp
+  EnumInfo.cpp
   Format.cpp
   GenInfo.cpp
   Interfaces.cpp
diff --git a/mlir/lib/TableGen/EnumInfo.cpp b/mlir/lib/TableGen/EnumInfo.cpp
new file mode 100644
index 0000000000000..9f491d30f0e7f
--- /dev/null
+++ b/mlir/lib/TableGen/EnumInfo.cpp
@@ -0,0 +1,130 @@
+//===- EnumInfo.cpp - EnumInfo wrapper class ----------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/TableGen/EnumInfo.h"
+#include "mlir/TableGen/Attribute.h"
+#include "llvm/TableGen/Record.h"
+
+using namespace mlir;
+using namespace mlir::tblgen;
+
+using llvm::DefInit;
+using llvm::Init;
+using llvm::Record;
+
+EnumCase::EnumCase(const Record *record) : def(record) {
+  assert(def->isSubClassOf("EnumAttrCaseInfo") &&
+         "must be subclass of TableGen 'EnumAttrCaseInfo' class");
+}
+
+EnumCase::EnumCase(const DefInit *init) : EnumCase(init->getDef()) {}
+
+StringRef EnumCase::getSymbol() const {
+  return def->getValueAsString("symbol");
+}
+
+StringRef EnumCase::getStr() const { return def->getValueAsString("str"); }
+
+int64_t EnumCase::getValue() const { return def->getValueAsInt("value"); }
+
+const Record &EnumCase::getDef() const { return *def; }
+
+EnumInfo::EnumInfo(const Record *record) : def(record) {
+  assert(isSubClassOf("EnumAttrInfo") &&
+         "must be subclass of TableGen 'EnumAttrInfo' class");
+}
+
+EnumInfo::EnumInfo(const Record &record) : EnumInfo(&record) {}
+
+EnumInfo::EnumInfo(const DefInit *init) : EnumInfo(init->getDef()) {}
+
+bool EnumInfo::isSubClassOf(StringRef className) const {
+  return def->isSubClassOf(className);
+}
+
+bool EnumInfo::isEnumAttr() const { return isSubClassOf("EnumAttrInfo"); }
+
+std::optional<Attribute> EnumInfo::asEnumAttr() const {
+  if (isEnumAttr())
+    return Attribute(def);
+  return std::nullopt;
+}
+
+bool EnumInfo::isBitEnum() const { return isSubClassOf("BitEnumAttr"); }
+
+StringRef EnumInfo::getEnumClassName() const {
+  return def->getValueAsString("className");
+}
+
+StringRef EnumInfo::getSummary() const {
+  return def->getValueAsString("summary");
+}
+
+StringRef EnumInfo::getDescription() const {
+  return def->getValueAsString("description");
+}
+
+StringRef EnumInfo::getCppNamespace() const {
+  return def->getValueAsString("cppNamespace");
+}
+
+StringRef EnumInfo::getUnderlyingType() const {
+  return def->getValueAsString("underlyingType");
+}
+
+StringRef EnumInfo::getUnderlyingToSymbolFnName() const {
+  return def->getValueAsString("underlyingToSymbolFnName");
+}
+
+StringRef EnumInfo::getStringToSymbolFnName() const {
+  return def->getValueAsString("stringToSymbolFnName");
+}
+
+StringRef EnumInfo::getSymbolToStringFnName() const {
+  return def->getValueAsString("symbolToStringFnName");
+}
+
+StringRef EnumInfo::getSymbolToStringFnRetType() const {
+  return def->getValueAsString("symbolToStringFnRetType");
+}
+
+StringRef EnumInfo::getMaxEnumValFnName() const {
+  return def->getValueAsString("maxEnumValFnName");
+}
+
+std::vector<EnumCase> EnumInfo::getAllCases() const {
+  const auto *inits = def->getValueAsListInit("enumerants");
+
+  std::vector<EnumCase> cases;
+  cases.reserve(inits->size());
+
+  for (const Init *init : *inits) {
+    cases.emplace_back(cast<DefInit>(init));
+  }
+
+  return cases;
+}
+
+bool EnumInfo::genSpecializedAttr() const {
+  return isSubClassOf("EnumAttrInfo") &&
+         def->getValueAsBit("genSpecializedAttr");
+}
+
+const Record *EnumInfo::getBaseAttrClass() const {
+  return def->getValueAsDef("baseAttrClass");
+}
+
+StringRef EnumInfo::getSpecializedAttrClassName() const {
+  return def->getValueAsString("specializedAttrClassName");
+}
+
+bool EnumInfo::printBitEnumPrimaryGroups() const {
+  return def->getValueAsBit("printBitEnumPrimaryGroups");
+}
+
+const Record &EnumInfo::getDef() const { return *def; }
diff --git a/mlir/lib/TableGen/Pattern.cpp b/mlir/lib/TableGen/Pattern.cpp
index ac8c49c72d384..73e2803c21dae 100644
--- a/mlir/lib/TableGen/Pattern.cpp
+++ b/mlir/lib/TableGen/Pattern.cpp
@@ -57,9 +57,7 @@ bool DagLeaf::isNativeCodeCall() const {
 
 bool DagLeaf::isConstantAttr() const { return isSubClassOf("ConstantAttr"); }
 
-bool DagLeaf::isEnumAttrCase() const {
-  return isSubClassOf("EnumAttrCaseInfo");
-}
+bool DagLeaf::isEnumCase() const { return isSubClassOf("EnumAttrCaseInfo"); }
 
 bool DagLeaf::isStringAttr() const { return isa<llvm::StringInit>(def); }
 
@@ -74,9 +72,9 @@ ConstantAttr DagLeaf::getAsConstantAttr() const {
   return ConstantAttr(cast<DefInit>(def));
 }
 
-EnumAttrCase DagLeaf::getAsEnumAttrCase() const {
-  assert(isEnumAttrCase() && "the DAG leaf must be an enum attribute case");
-  return EnumAttrCase(cast<DefInit>(def));
+EnumCase DagLeaf::getAsEnumCase() const {
+  assert(isEnumCase() && "the DAG leaf must be an enum attribute case");
+  return EnumCase(cast<DefInit>(def));
 }
 
 std::string DagLeaf::getConditionTemplate() const {
@@ -776,7 +774,7 @@ void Pattern::collectBoundSymbols(DagNode tree, SymbolInfoMap &infoMap,
           verifyBind(infoMap.bindValue(treeArgName), treeArgName);
         } else {
           auto constraint = leaf.getAsConstraint();
-          bool isAttr = leaf.isAttrMatcher() || leaf.isEnumAttrCase() ||
+          bool isAttr = leaf.isAttrMatcher() || leaf.isEnumCase() ||
                         leaf.isConstantAttr() ||
                         constraint.getKind() == Constraint::Kind::CK_Attr;
 
diff --git a/mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp
index 3f660ae151c74..5d4d9e90fff67 100644
--- a/mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp
+++ b/mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp
@@ -15,6 +15,7 @@
 #include "mlir/TableGen/AttrOrTypeDef.h"
 #include "mlir/TableGen/Attribute.h"
 #include "mlir/TableGen/Dialect.h"
+#include "mlir/TableGen/EnumInfo.h"
 #include "mlir/TableGen/GenInfo.h"
 #include "llvm/Support/FormatVariadic.h"
 #include "llvm/TableGen/Record.h"
@@ -44,14 +45,14 @@ static std::string makePythonEnumCaseName(StringRef name) {
 }
 
 /// Emits the Python class for the given enum.
-static void emitEnumClass(EnumAttr enumAttr, raw_ostream &os) {
-  os << formatv("class {0}({1}):\n", enumAttr.getEnumClassName(),
-                enumAttr.isBitEnum() ? "IntFlag" : "IntEnum");
-  if (!enumAttr.getSummary().empty())
-    os << formatv("    \"\"\"{0}\"\"\"\n", enumAttr.getSummary());
+static void emitEnumClass(EnumInfo enumInfo, raw_ostream &os) {
+  os << formatv("class {0}({1}):\n", enumInfo.getEnumClassName(),
+                enumInfo.isBitEnum() ? "IntFlag" : "IntEnum");
+  if (!enumInfo.getSummary().empty())
+    os << formatv("    \"\"\"{0}\"\"\"\n", enumInfo.getSummary());
   os << "\n";
 
-  for (const EnumAttrCase &enumCase : enumAttr.getAllCases()) {
+  for (const EnumCase &enumCase : enumInfo.getAllCases()) {
     os << formatv("    {0} = {1}\n",
                   makePythonEnumCaseName(enumCase.getSymbol()),
                   enumCase.getValue() >= 0 ? std::to_string(enumCase.getValue())
@@ -60,7 +61,7 @@ static void emitEnumClass(EnumAttr enumAttr, raw_ostream &os) {
 
   os << "\n";
 
-  if (enumAttr.isBitEnum()) {
+  if (enumInfo.isBitEnum()) {
     os << formatv("    def __iter__(self):\n"
                   "        return iter([case for case in type(self) if "
                   "(self & case) is case])\n");
@@ -70,17 +71,17 @@ static void emitEnumClass(EnumAttr enumAttr, raw_ostream &os) {
   }
 
   os << formatv("    def __str__(self):\n");
-  if (enumAttr.isBitEnum())
+  if (enumInfo.isBitEnum())
     os << formatv("        if len(self) > 1:\n"
                   "            return \"{0}\".join(map(str, self))\n",
-                  enumAttr.getDef().getValueAs...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented Mar 24, 2025

@llvm/pr-subscribers-mlir-spirv

Author: Krzysztof Drewniak (krzysz00)

Changes

This moves the EnumAttrCase and EnumAttr classes from Attribute.h/.cpp to a new EnumInfo.h/cpp and renames them to EnumCase and EnumInfo, respectively.

This doesn't change any of the tablegen files or any user-facing aspects of the enum attribute generation system, just reorganizes code in order to make main PR (#132148) shorter.


Patch is 74.54 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/132650.diff

15 Files Affected:

  • (modified) mlir/include/mlir/TableGen/Attribute.h (-74)
  • (added) mlir/include/mlir/TableGen/EnumInfo.h (+135)
  • (modified) mlir/include/mlir/TableGen/Pattern.h (+6-5)
  • (modified) mlir/lib/TableGen/Attribute.cpp (-94)
  • (modified) mlir/lib/TableGen/CMakeLists.txt (+1)
  • (added) mlir/lib/TableGen/EnumInfo.cpp (+130)
  • (modified) mlir/lib/TableGen/Pattern.cpp (+5-7)
  • (modified) mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp (+26-21)
  • (modified) mlir/tools/mlir-tblgen/EnumsGen.cpp (+85-81)
  • (modified) mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp (+38-37)
  • (modified) mlir/tools/mlir-tblgen/OpDocGen.cpp (+9-8)
  • (modified) mlir/tools/mlir-tblgen/OpFormatGen.cpp (+18-17)
  • (modified) mlir/tools/mlir-tblgen/RewriterGen.cpp (+4-4)
  • (modified) mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp (+46-45)
  • (modified) mlir/tools/mlir-tblgen/TosaUtilsGen.cpp (+3-2)
diff --git a/mlir/include/mlir/TableGen/Attribute.h b/mlir/include/mlir/TableGen/Attribute.h
index 62720e74849fc..dee81880bacab 100644
--- a/mlir/include/mlir/TableGen/Attribute.h
+++ b/mlir/include/mlir/TableGen/Attribute.h
@@ -16,7 +16,6 @@
 
 #include "mlir/Support/LLVM.h"
 #include "mlir/TableGen/Constraint.h"
-#include "llvm/ADT/StringRef.h"
 
 namespace llvm {
 class DefInit;
@@ -136,79 +135,6 @@ class ConstantAttr {
   const llvm::Record *def;
 };
 
-// Wrapper class providing helper methods for accessing enum attribute cases
-// defined in TableGen. This is used for enum attribute case backed by both
-// StringAttr and IntegerAttr.
-class EnumAttrCase : public Attribute {
-public:
-  explicit EnumAttrCase(const llvm::Record *record);
-  explicit EnumAttrCase(const llvm::DefInit *init);
-
-  // Returns the symbol of this enum attribute case.
-  StringRef getSymbol() const;
-
-  // Returns the textual representation of this enum attribute case.
-  StringRef getStr() const;
-
-  // Returns the value of this enum attribute case.
-  int64_t getValue() const;
-
-  // Returns the TableGen definition this EnumAttrCase was constructed from.
-  const llvm::Record &getDef() const;
-};
-
-// Wrapper class providing helper methods for accessing enum attributes defined
-// in TableGen.This is used for enum attribute case backed by both StringAttr
-// and IntegerAttr.
-class EnumAttr : public Attribute {
-public:
-  explicit EnumAttr(const llvm::Record *record);
-  explicit EnumAttr(const llvm::Record &record);
-  explicit EnumAttr(const llvm::DefInit *init);
-
-  static bool classof(const Attribute *attr);
-
-  // Returns true if this is a bit enum attribute.
-  bool isBitEnum() const;
-
-  // Returns the enum class name.
-  StringRef getEnumClassName() const;
-
-  // Returns the C++ namespaces this enum class should be placed in.
-  StringRef getCppNamespace() const;
-
-  // Returns the underlying type.
-  StringRef getUnderlyingType() const;
-
-  // Returns the name of the utility function that converts a value of the
-  // underlying type to the corresponding symbol.
-  StringRef getUnderlyingToSymbolFnName() const;
-
-  // Returns the name of the utility function that converts a string to the
-  // corresponding symbol.
-  StringRef getStringToSymbolFnName() const;
-
-  // Returns the name of the utility function that converts a symbol to the
-  // corresponding string.
-  StringRef getSymbolToStringFnName() const;
-
-  // Returns the return type of the utility function that converts a symbol to
-  // the corresponding string.
-  StringRef getSymbolToStringFnRetType() const;
-
-  // Returns the name of the utilit function that returns the max enum value
-  // used within the enum class.
-  StringRef getMaxEnumValFnName() const;
-
-  // Returns all allowed cases for this enum attribute.
-  std::vector<EnumAttrCase> getAllCases() const;
-
-  bool genSpecializedAttr() const;
-  const llvm::Record *getBaseAttrClass() const;
-  StringRef getSpecializedAttrClassName() const;
-  bool printBitEnumPrimaryGroups() const;
-};
-
 // Name of infer type op interface.
 extern const char *inferTypeOpInterface;
 
diff --git a/mlir/include/mlir/TableGen/EnumInfo.h b/mlir/include/mlir/TableGen/EnumInfo.h
new file mode 100644
index 0000000000000..196267864f325
--- /dev/null
+++ b/mlir/include/mlir/TableGen/EnumInfo.h
@@ -0,0 +1,135 @@
+//===- EnumInfo.h - EnumInfo wrapper class --------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// EnumInfo wrapper to simplify using a TableGen Record defining an Enum
+// via EnumInfo and its `EnumCase`s.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TABLEGEN_ENUMINFO_H_
+#define MLIR_TABLEGEN_ENUMINFO_H_
+
+#include "mlir/Support/LLVM.h"
+#include "mlir/TableGen/Attribute.h"
+#include "llvm/ADT/StringRef.h"
+
+namespace llvm {
+class DefInit;
+class Record;
+} // namespace llvm
+
+namespace mlir {
+namespace tblgen {
+
+// Wrapper class providing around enum cases defined in TableGen.
+class EnumCase {
+public:
+  explicit EnumCase(const llvm::Record *record);
+  explicit EnumCase(const llvm::DefInit *init);
+
+  // Returns the symbol of this enum attribute case.
+  StringRef getSymbol() const;
+
+  // Returns the textual representation of this enum attribute case.
+  StringRef getStr() const;
+
+  // Returns the value of this enum attribute case.
+  int64_t getValue() const;
+
+  // Returns the TableGen definition this EnumAttrCase was constructed from.
+  const llvm::Record &getDef() const;
+
+protected:
+  // The TableGen definition of this constraint.
+  const llvm::Record *def;
+};
+
+// Wrapper class providing helper methods for accessing enums defined
+// in TableGen using EnumInfo. Some methods are only applicable when
+// the enum is also an attribute, or only when it is a bit enum.
+class EnumInfo {
+public:
+  explicit EnumInfo(const llvm::Record *record);
+  explicit EnumInfo(const llvm::Record &record);
+  explicit EnumInfo(const llvm::DefInit *init);
+
+  // Returns true if the given EnumInfo is a subclass of the named TableGen
+  // class.
+  bool isSubClassOf(StringRef className) const;
+
+  // Returns true if this enum is an EnumAttrInfo, thus making it define an
+  // attribute.
+  bool isEnumAttr() const;
+
+  // Create the `Attribute` wrapper around this EnumInfo if it is defining an
+  // attribute.
+  std::optional<Attribute> asEnumAttr() const;
+
+  // Returns true if this is a bit enum.
+  bool isBitEnum() const;
+
+  // Returns the enum class name.
+  StringRef getEnumClassName() const;
+
+  // Returns the C++ namespaces this enum class should be placed in.
+  StringRef getCppNamespace() const;
+
+  // Returns the summary of the enum.
+  StringRef getSummary() const;
+
+  // Returns the description of the enum.
+  StringRef getDescription() const;
+
+  // Returns the underlying type.
+  StringRef getUnderlyingType() const;
+
+  // Returns the name of the utility function that converts a value of the
+  // underlying type to the corresponding symbol.
+  StringRef getUnderlyingToSymbolFnName() const;
+
+  // Returns the name of the utility function that converts a string to the
+  // corresponding symbol.
+  StringRef getStringToSymbolFnName() const;
+
+  // Returns the name of the utility function that converts a symbol to the
+  // corresponding string.
+  StringRef getSymbolToStringFnName() const;
+
+  // Returns the return type of the utility function that converts a symbol to
+  // the corresponding string.
+  StringRef getSymbolToStringFnRetType() const;
+
+  // Returns the name of the utilit function that returns the max enum value
+  // used within the enum class.
+  StringRef getMaxEnumValFnName() const;
+
+  // Returns all allowed cases for this enum attribute.
+  std::vector<EnumCase> getAllCases() const;
+
+  // Only applicable for enum attributes.
+
+  bool genSpecializedAttr() const;
+  const llvm::Record *getBaseAttrClass() const;
+  StringRef getSpecializedAttrClassName() const;
+
+  // Only applicable for bit enums.
+
+  bool printBitEnumPrimaryGroups() const;
+
+  // Returns the TableGen definition this EnumAttrCase was constructed from.
+  const llvm::Record &getDef() const;
+
+protected:
+  // The TableGen definition of this constraint.
+  const llvm::Record *def;
+};
+
+} // namespace tblgen
+} // namespace mlir
+
+#endif
diff --git a/mlir/include/mlir/TableGen/Pattern.h b/mlir/include/mlir/TableGen/Pattern.h
index 80f38fdeffee0..1c9e128f0a0fb 100644
--- a/mlir/include/mlir/TableGen/Pattern.h
+++ b/mlir/include/mlir/TableGen/Pattern.h
@@ -16,6 +16,7 @@
 
 #include "mlir/Support/LLVM.h"
 #include "mlir/TableGen/Argument.h"
+#include "mlir/TableGen/EnumInfo.h"
 #include "mlir/TableGen/Operator.h"
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/Hashing.h"
@@ -78,8 +79,8 @@ class DagLeaf {
   // Returns true if this DAG leaf is specifying a constant attribute.
   bool isConstantAttr() const;
 
-  // Returns true if this DAG leaf is specifying an enum attribute case.
-  bool isEnumAttrCase() const;
+  // Returns true if this DAG leaf is specifying an enum case.
+  bool isEnumCase() const;
 
   // Returns true if this DAG leaf is specifying a string attribute.
   bool isStringAttr() const;
@@ -90,9 +91,9 @@ class DagLeaf {
   // Returns this DAG leaf as an constant attribute. Asserts if fails.
   ConstantAttr getAsConstantAttr() const;
 
-  // Returns this DAG leaf as an enum attribute case.
-  // Precondition: isEnumAttrCase()
-  EnumAttrCase getAsEnumAttrCase() const;
+  // Returns this DAG leaf as an enum case.
+  // Precondition: isEnumCase()
+  EnumCase getAsEnumCase() const;
 
   // Returns the matching condition template inside this DAG leaf. Assumes the
   // leaf is an operand/attribute matcher and asserts otherwise.
diff --git a/mlir/lib/TableGen/Attribute.cpp b/mlir/lib/TableGen/Attribute.cpp
index f9fc58a40f334..142d194260942 100644
--- a/mlir/lib/TableGen/Attribute.cpp
+++ b/mlir/lib/TableGen/Attribute.cpp
@@ -146,98 +146,4 @@ StringRef ConstantAttr::getConstantValue() const {
   return def->getValueAsString("value");
 }
 
-EnumAttrCase::EnumAttrCase(const Record *record) : Attribute(record) {
-  assert(isSubClassOf("EnumAttrCaseInfo") &&
-         "must be subclass of TableGen 'EnumAttrInfo' class");
-}
-
-EnumAttrCase::EnumAttrCase(const DefInit *init)
-    : EnumAttrCase(init->getDef()) {}
-
-StringRef EnumAttrCase::getSymbol() const {
-  return def->getValueAsString("symbol");
-}
-
-StringRef EnumAttrCase::getStr() const { return def->getValueAsString("str"); }
-
-int64_t EnumAttrCase::getValue() const { return def->getValueAsInt("value"); }
-
-const Record &EnumAttrCase::getDef() const { return *def; }
-
-EnumAttr::EnumAttr(const Record *record) : Attribute(record) {
-  assert(isSubClassOf("EnumAttrInfo") &&
-         "must be subclass of TableGen 'EnumAttr' class");
-}
-
-EnumAttr::EnumAttr(const Record &record) : Attribute(&record) {}
-
-EnumAttr::EnumAttr(const DefInit *init) : EnumAttr(init->getDef()) {}
-
-bool EnumAttr::classof(const Attribute *attr) {
-  return attr->isSubClassOf("EnumAttrInfo");
-}
-
-bool EnumAttr::isBitEnum() const { return isSubClassOf("BitEnumAttr"); }
-
-StringRef EnumAttr::getEnumClassName() const {
-  return def->getValueAsString("className");
-}
-
-StringRef EnumAttr::getCppNamespace() const {
-  return def->getValueAsString("cppNamespace");
-}
-
-StringRef EnumAttr::getUnderlyingType() const {
-  return def->getValueAsString("underlyingType");
-}
-
-StringRef EnumAttr::getUnderlyingToSymbolFnName() const {
-  return def->getValueAsString("underlyingToSymbolFnName");
-}
-
-StringRef EnumAttr::getStringToSymbolFnName() const {
-  return def->getValueAsString("stringToSymbolFnName");
-}
-
-StringRef EnumAttr::getSymbolToStringFnName() const {
-  return def->getValueAsString("symbolToStringFnName");
-}
-
-StringRef EnumAttr::getSymbolToStringFnRetType() const {
-  return def->getValueAsString("symbolToStringFnRetType");
-}
-
-StringRef EnumAttr::getMaxEnumValFnName() const {
-  return def->getValueAsString("maxEnumValFnName");
-}
-
-std::vector<EnumAttrCase> EnumAttr::getAllCases() const {
-  const auto *inits = def->getValueAsListInit("enumerants");
-
-  std::vector<EnumAttrCase> cases;
-  cases.reserve(inits->size());
-
-  for (const Init *init : *inits) {
-    cases.emplace_back(cast<DefInit>(init));
-  }
-
-  return cases;
-}
-
-bool EnumAttr::genSpecializedAttr() const {
-  return def->getValueAsBit("genSpecializedAttr");
-}
-
-const Record *EnumAttr::getBaseAttrClass() const {
-  return def->getValueAsDef("baseAttrClass");
-}
-
-StringRef EnumAttr::getSpecializedAttrClassName() const {
-  return def->getValueAsString("specializedAttrClassName");
-}
-
-bool EnumAttr::printBitEnumPrimaryGroups() const {
-  return def->getValueAsBit("printBitEnumPrimaryGroups");
-}
-
 const char * ::mlir::tblgen::inferTypeOpInterface = "InferTypeOpInterface";
diff --git a/mlir/lib/TableGen/CMakeLists.txt b/mlir/lib/TableGen/CMakeLists.txt
index c4104e644147c..a90c55847718e 100644
--- a/mlir/lib/TableGen/CMakeLists.txt
+++ b/mlir/lib/TableGen/CMakeLists.txt
@@ -20,6 +20,7 @@ llvm_add_library(MLIRTableGen STATIC
   CodeGenHelpers.cpp
   Constraint.cpp
   Dialect.cpp
+  EnumInfo.cpp
   Format.cpp
   GenInfo.cpp
   Interfaces.cpp
diff --git a/mlir/lib/TableGen/EnumInfo.cpp b/mlir/lib/TableGen/EnumInfo.cpp
new file mode 100644
index 0000000000000..9f491d30f0e7f
--- /dev/null
+++ b/mlir/lib/TableGen/EnumInfo.cpp
@@ -0,0 +1,130 @@
+//===- EnumInfo.cpp - EnumInfo wrapper class ----------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/TableGen/EnumInfo.h"
+#include "mlir/TableGen/Attribute.h"
+#include "llvm/TableGen/Record.h"
+
+using namespace mlir;
+using namespace mlir::tblgen;
+
+using llvm::DefInit;
+using llvm::Init;
+using llvm::Record;
+
+EnumCase::EnumCase(const Record *record) : def(record) {
+  assert(def->isSubClassOf("EnumAttrCaseInfo") &&
+         "must be subclass of TableGen 'EnumAttrCaseInfo' class");
+}
+
+EnumCase::EnumCase(const DefInit *init) : EnumCase(init->getDef()) {}
+
+StringRef EnumCase::getSymbol() const {
+  return def->getValueAsString("symbol");
+}
+
+StringRef EnumCase::getStr() const { return def->getValueAsString("str"); }
+
+int64_t EnumCase::getValue() const { return def->getValueAsInt("value"); }
+
+const Record &EnumCase::getDef() const { return *def; }
+
+EnumInfo::EnumInfo(const Record *record) : def(record) {
+  assert(isSubClassOf("EnumAttrInfo") &&
+         "must be subclass of TableGen 'EnumAttrInfo' class");
+}
+
+EnumInfo::EnumInfo(const Record &record) : EnumInfo(&record) {}
+
+EnumInfo::EnumInfo(const DefInit *init) : EnumInfo(init->getDef()) {}
+
+bool EnumInfo::isSubClassOf(StringRef className) const {
+  return def->isSubClassOf(className);
+}
+
+bool EnumInfo::isEnumAttr() const { return isSubClassOf("EnumAttrInfo"); }
+
+std::optional<Attribute> EnumInfo::asEnumAttr() const {
+  if (isEnumAttr())
+    return Attribute(def);
+  return std::nullopt;
+}
+
+bool EnumInfo::isBitEnum() const { return isSubClassOf("BitEnumAttr"); }
+
+StringRef EnumInfo::getEnumClassName() const {
+  return def->getValueAsString("className");
+}
+
+StringRef EnumInfo::getSummary() const {
+  return def->getValueAsString("summary");
+}
+
+StringRef EnumInfo::getDescription() const {
+  return def->getValueAsString("description");
+}
+
+StringRef EnumInfo::getCppNamespace() const {
+  return def->getValueAsString("cppNamespace");
+}
+
+StringRef EnumInfo::getUnderlyingType() const {
+  return def->getValueAsString("underlyingType");
+}
+
+StringRef EnumInfo::getUnderlyingToSymbolFnName() const {
+  return def->getValueAsString("underlyingToSymbolFnName");
+}
+
+StringRef EnumInfo::getStringToSymbolFnName() const {
+  return def->getValueAsString("stringToSymbolFnName");
+}
+
+StringRef EnumInfo::getSymbolToStringFnName() const {
+  return def->getValueAsString("symbolToStringFnName");
+}
+
+StringRef EnumInfo::getSymbolToStringFnRetType() const {
+  return def->getValueAsString("symbolToStringFnRetType");
+}
+
+StringRef EnumInfo::getMaxEnumValFnName() const {
+  return def->getValueAsString("maxEnumValFnName");
+}
+
+std::vector<EnumCase> EnumInfo::getAllCases() const {
+  const auto *inits = def->getValueAsListInit("enumerants");
+
+  std::vector<EnumCase> cases;
+  cases.reserve(inits->size());
+
+  for (const Init *init : *inits) {
+    cases.emplace_back(cast<DefInit>(init));
+  }
+
+  return cases;
+}
+
+bool EnumInfo::genSpecializedAttr() const {
+  return isSubClassOf("EnumAttrInfo") &&
+         def->getValueAsBit("genSpecializedAttr");
+}
+
+const Record *EnumInfo::getBaseAttrClass() const {
+  return def->getValueAsDef("baseAttrClass");
+}
+
+StringRef EnumInfo::getSpecializedAttrClassName() const {
+  return def->getValueAsString("specializedAttrClassName");
+}
+
+bool EnumInfo::printBitEnumPrimaryGroups() const {
+  return def->getValueAsBit("printBitEnumPrimaryGroups");
+}
+
+const Record &EnumInfo::getDef() const { return *def; }
diff --git a/mlir/lib/TableGen/Pattern.cpp b/mlir/lib/TableGen/Pattern.cpp
index ac8c49c72d384..73e2803c21dae 100644
--- a/mlir/lib/TableGen/Pattern.cpp
+++ b/mlir/lib/TableGen/Pattern.cpp
@@ -57,9 +57,7 @@ bool DagLeaf::isNativeCodeCall() const {
 
 bool DagLeaf::isConstantAttr() const { return isSubClassOf("ConstantAttr"); }
 
-bool DagLeaf::isEnumAttrCase() const {
-  return isSubClassOf("EnumAttrCaseInfo");
-}
+bool DagLeaf::isEnumCase() const { return isSubClassOf("EnumAttrCaseInfo"); }
 
 bool DagLeaf::isStringAttr() const { return isa<llvm::StringInit>(def); }
 
@@ -74,9 +72,9 @@ ConstantAttr DagLeaf::getAsConstantAttr() const {
   return ConstantAttr(cast<DefInit>(def));
 }
 
-EnumAttrCase DagLeaf::getAsEnumAttrCase() const {
-  assert(isEnumAttrCase() && "the DAG leaf must be an enum attribute case");
-  return EnumAttrCase(cast<DefInit>(def));
+EnumCase DagLeaf::getAsEnumCase() const {
+  assert(isEnumCase() && "the DAG leaf must be an enum attribute case");
+  return EnumCase(cast<DefInit>(def));
 }
 
 std::string DagLeaf::getConditionTemplate() const {
@@ -776,7 +774,7 @@ void Pattern::collectBoundSymbols(DagNode tree, SymbolInfoMap &infoMap,
           verifyBind(infoMap.bindValue(treeArgName), treeArgName);
         } else {
           auto constraint = leaf.getAsConstraint();
-          bool isAttr = leaf.isAttrMatcher() || leaf.isEnumAttrCase() ||
+          bool isAttr = leaf.isAttrMatcher() || leaf.isEnumCase() ||
                         leaf.isConstantAttr() ||
                         constraint.getKind() == Constraint::Kind::CK_Attr;
 
diff --git a/mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp
index 3f660ae151c74..5d4d9e90fff67 100644
--- a/mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp
+++ b/mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp
@@ -15,6 +15,7 @@
 #include "mlir/TableGen/AttrOrTypeDef.h"
 #include "mlir/TableGen/Attribute.h"
 #include "mlir/TableGen/Dialect.h"
+#include "mlir/TableGen/EnumInfo.h"
 #include "mlir/TableGen/GenInfo.h"
 #include "llvm/Support/FormatVariadic.h"
 #include "llvm/TableGen/Record.h"
@@ -44,14 +45,14 @@ static std::string makePythonEnumCaseName(StringRef name) {
 }
 
 /// Emits the Python class for the given enum.
-static void emitEnumClass(EnumAttr enumAttr, raw_ostream &os) {
-  os << formatv("class {0}({1}):\n", enumAttr.getEnumClassName(),
-                enumAttr.isBitEnum() ? "IntFlag" : "IntEnum");
-  if (!enumAttr.getSummary().empty())
-    os << formatv("    \"\"\"{0}\"\"\"\n", enumAttr.getSummary());
+static void emitEnumClass(EnumInfo enumInfo, raw_ostream &os) {
+  os << formatv("class {0}({1}):\n", enumInfo.getEnumClassName(),
+                enumInfo.isBitEnum() ? "IntFlag" : "IntEnum");
+  if (!enumInfo.getSummary().empty())
+    os << formatv("    \"\"\"{0}\"\"\"\n", enumInfo.getSummary());
   os << "\n";
 
-  for (const EnumAttrCase &enumCase : enumAttr.getAllCases()) {
+  for (const EnumCase &enumCase : enumInfo.getAllCases()) {
     os << formatv("    {0} = {1}\n",
                   makePythonEnumCaseName(enumCase.getSymbol()),
                   enumCase.getValue() >= 0 ? std::to_string(enumCase.getValue())
@@ -60,7 +61,7 @@ static void emitEnumClass(EnumAttr enumAttr, raw_ostream &os) {
 
   os << "\n";
 
-  if (enumAttr.isBitEnum()) {
+  if (enumInfo.isBitEnum()) {
     os << formatv("    def __iter__(self):\n"
                   "        return iter([case for case in type(self) if "
                   "(self & case) is case])\n");
@@ -70,17 +71,17 @@ static void emitEnumClass(EnumAttr enumAttr, raw_ostream &os) {
   }
 
   os << formatv("    def __str__(self):\n");
-  if (enumAttr.isBitEnum())
+  if (enumInfo.isBitEnum())
     os << formatv("        if len(self) > 1:\n"
                   "            return \"{0}\".join(map(str, self))\n",
-                  enumAttr.getDef().getValueAs...
[truncated]

Comment on lines 26 to 27
namespace mlir {
namespace tblgen {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: you can combine these

Suggested change
namespace mlir {
namespace tblgen {
namespace mlir::tblgen {

@krzysz00 krzysz00 merged commit 263ec72 into main Mar 27, 2025
11 checks passed
@krzysz00 krzysz00 deleted the users/krzysz00/move-enum-info branch March 27, 2025 01:26
@GleasonK
Copy link
Contributor

Something about this series of changes broke our enum docs:
https://github.com/openxla/stablehlo/blob/848b4a1a6033022d09646bb2144ee3f91552cec4/stablehlo/dialect/ChloEnums.td#L52-L71

image

I'll try to investigate / add test / fix next week, but wanted to ping in case anythings obviously wrong with our enums

@GleasonK
Copy link
Contributor

Ah - False alarm, it actually just fixed some redundancy! I'm seeing that the enum values are still listed below in the actual enum (not enum attr) section. Apologies for the noise, and thanks for the improvement :)!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants