Skip to content

[mlir][NFC] Move and rename EnumAttrCase, EnumAttr C++ classes #132650

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Mar 27, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 0 additions & 74 deletions mlir/include/mlir/TableGen/Attribute.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@

#include "mlir/Support/LLVM.h"
#include "mlir/TableGen/Constraint.h"
#include "llvm/ADT/StringRef.h"

namespace llvm {
class DefInit;
Expand Down Expand Up @@ -136,79 +135,6 @@ class ConstantAttr {
const llvm::Record *def;
};

// Wrapper class providing helper methods for accessing enum attribute cases
// defined in TableGen. This is used for enum attribute case backed by both
// StringAttr and IntegerAttr.
class EnumAttrCase : public Attribute {
public:
explicit EnumAttrCase(const llvm::Record *record);
explicit EnumAttrCase(const llvm::DefInit *init);

// Returns the symbol of this enum attribute case.
StringRef getSymbol() const;

// Returns the textual representation of this enum attribute case.
StringRef getStr() const;

// Returns the value of this enum attribute case.
int64_t getValue() const;

// Returns the TableGen definition this EnumAttrCase was constructed from.
const llvm::Record &getDef() const;
};

// Wrapper class providing helper methods for accessing enum attributes defined
// in TableGen.This is used for enum attribute case backed by both StringAttr
// and IntegerAttr.
class EnumAttr : public Attribute {
public:
explicit EnumAttr(const llvm::Record *record);
explicit EnumAttr(const llvm::Record &record);
explicit EnumAttr(const llvm::DefInit *init);

static bool classof(const Attribute *attr);

// Returns true if this is a bit enum attribute.
bool isBitEnum() const;

// Returns the enum class name.
StringRef getEnumClassName() const;

// Returns the C++ namespaces this enum class should be placed in.
StringRef getCppNamespace() const;

// Returns the underlying type.
StringRef getUnderlyingType() const;

// Returns the name of the utility function that converts a value of the
// underlying type to the corresponding symbol.
StringRef getUnderlyingToSymbolFnName() const;

// Returns the name of the utility function that converts a string to the
// corresponding symbol.
StringRef getStringToSymbolFnName() const;

// Returns the name of the utility function that converts a symbol to the
// corresponding string.
StringRef getSymbolToStringFnName() const;

// Returns the return type of the utility function that converts a symbol to
// the corresponding string.
StringRef getSymbolToStringFnRetType() const;

// Returns the name of the utilit function that returns the max enum value
// used within the enum class.
StringRef getMaxEnumValFnName() const;

// Returns all allowed cases for this enum attribute.
std::vector<EnumAttrCase> getAllCases() const;

bool genSpecializedAttr() const;
const llvm::Record *getBaseAttrClass() const;
StringRef getSpecializedAttrClassName() const;
bool printBitEnumPrimaryGroups() const;
};

// Name of infer type op interface.
extern const char *inferTypeOpInterface;

Expand Down
133 changes: 133 additions & 0 deletions mlir/include/mlir/TableGen/EnumInfo.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
//===- EnumInfo.h - EnumInfo wrapper class --------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// EnumInfo wrapper to simplify using a TableGen Record defining an Enum
// via EnumInfo and its `EnumCase`s.
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_TABLEGEN_ENUMINFO_H_
#define MLIR_TABLEGEN_ENUMINFO_H_

#include "mlir/Support/LLVM.h"
#include "mlir/TableGen/Attribute.h"
#include "llvm/ADT/StringRef.h"

namespace llvm {
class DefInit;
class Record;
} // namespace llvm

namespace mlir::tblgen {

// Wrapper class providing around enum cases defined in TableGen.
class EnumCase {
public:
explicit EnumCase(const llvm::Record *record);
explicit EnumCase(const llvm::DefInit *init);

// Returns the symbol of this enum attribute case.
StringRef getSymbol() const;

// Returns the textual representation of this enum attribute case.
StringRef getStr() const;

// Returns the value of this enum attribute case.
int64_t getValue() const;

// Returns the TableGen definition this EnumAttrCase was constructed from.
const llvm::Record &getDef() const;

protected:
// The TableGen definition of this constraint.
const llvm::Record *def;
};

// Wrapper class providing helper methods for accessing enums defined
// in TableGen using EnumInfo. Some methods are only applicable when
// the enum is also an attribute, or only when it is a bit enum.
class EnumInfo {
public:
explicit EnumInfo(const llvm::Record *record);
explicit EnumInfo(const llvm::Record &record);
explicit EnumInfo(const llvm::DefInit *init);

// Returns true if the given EnumInfo is a subclass of the named TableGen
// class.
bool isSubClassOf(StringRef className) const;

// Returns true if this enum is an EnumAttrInfo, thus making it define an
// attribute.
bool isEnumAttr() const;

// Create the `Attribute` wrapper around this EnumInfo if it is defining an
// attribute.
std::optional<Attribute> asEnumAttr() const;

// Returns true if this is a bit enum.
bool isBitEnum() const;

// Returns the enum class name.
StringRef getEnumClassName() const;

// Returns the C++ namespaces this enum class should be placed in.
StringRef getCppNamespace() const;

// Returns the summary of the enum.
StringRef getSummary() const;

// Returns the description of the enum.
StringRef getDescription() const;

// Returns the underlying type.
StringRef getUnderlyingType() const;

// Returns the name of the utility function that converts a value of the
// underlying type to the corresponding symbol.
StringRef getUnderlyingToSymbolFnName() const;

// Returns the name of the utility function that converts a string to the
// corresponding symbol.
StringRef getStringToSymbolFnName() const;

// Returns the name of the utility function that converts a symbol to the
// corresponding string.
StringRef getSymbolToStringFnName() const;

// Returns the return type of the utility function that converts a symbol to
// the corresponding string.
StringRef getSymbolToStringFnRetType() const;

// Returns the name of the utilit function that returns the max enum value
// used within the enum class.
StringRef getMaxEnumValFnName() const;

// Returns all allowed cases for this enum attribute.
std::vector<EnumCase> getAllCases() const;

// Only applicable for enum attributes.

bool genSpecializedAttr() const;
const llvm::Record *getBaseAttrClass() const;
StringRef getSpecializedAttrClassName() const;

// Only applicable for bit enums.

bool printBitEnumPrimaryGroups() const;

// Returns the TableGen definition this EnumAttrCase was constructed from.
const llvm::Record &getDef() const;

protected:
// The TableGen definition of this constraint.
const llvm::Record *def;
};

} // namespace mlir::tblgen

#endif
11 changes: 6 additions & 5 deletions mlir/include/mlir/TableGen/Pattern.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

#include "mlir/Support/LLVM.h"
#include "mlir/TableGen/Argument.h"
#include "mlir/TableGen/EnumInfo.h"
#include "mlir/TableGen/Operator.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/Hashing.h"
Expand Down Expand Up @@ -78,8 +79,8 @@ class DagLeaf {
// Returns true if this DAG leaf is specifying a constant attribute.
bool isConstantAttr() const;

// Returns true if this DAG leaf is specifying an enum attribute case.
bool isEnumAttrCase() const;
// Returns true if this DAG leaf is specifying an enum case.
bool isEnumCase() const;

// Returns true if this DAG leaf is specifying a string attribute.
bool isStringAttr() const;
Expand All @@ -90,9 +91,9 @@ class DagLeaf {
// Returns this DAG leaf as an constant attribute. Asserts if fails.
ConstantAttr getAsConstantAttr() const;

// Returns this DAG leaf as an enum attribute case.
// Precondition: isEnumAttrCase()
EnumAttrCase getAsEnumAttrCase() const;
// Returns this DAG leaf as an enum case.
// Precondition: isEnumCase()
EnumCase getAsEnumCase() const;

// Returns the matching condition template inside this DAG leaf. Assumes the
// leaf is an operand/attribute matcher and asserts otherwise.
Expand Down
94 changes: 0 additions & 94 deletions mlir/lib/TableGen/Attribute.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -146,98 +146,4 @@ StringRef ConstantAttr::getConstantValue() const {
return def->getValueAsString("value");
}

EnumAttrCase::EnumAttrCase(const Record *record) : Attribute(record) {
assert(isSubClassOf("EnumAttrCaseInfo") &&
"must be subclass of TableGen 'EnumAttrInfo' class");
}

EnumAttrCase::EnumAttrCase(const DefInit *init)
: EnumAttrCase(init->getDef()) {}

StringRef EnumAttrCase::getSymbol() const {
return def->getValueAsString("symbol");
}

StringRef EnumAttrCase::getStr() const { return def->getValueAsString("str"); }

int64_t EnumAttrCase::getValue() const { return def->getValueAsInt("value"); }

const Record &EnumAttrCase::getDef() const { return *def; }

EnumAttr::EnumAttr(const Record *record) : Attribute(record) {
assert(isSubClassOf("EnumAttrInfo") &&
"must be subclass of TableGen 'EnumAttr' class");
}

EnumAttr::EnumAttr(const Record &record) : Attribute(&record) {}

EnumAttr::EnumAttr(const DefInit *init) : EnumAttr(init->getDef()) {}

bool EnumAttr::classof(const Attribute *attr) {
return attr->isSubClassOf("EnumAttrInfo");
}

bool EnumAttr::isBitEnum() const { return isSubClassOf("BitEnumAttr"); }

StringRef EnumAttr::getEnumClassName() const {
return def->getValueAsString("className");
}

StringRef EnumAttr::getCppNamespace() const {
return def->getValueAsString("cppNamespace");
}

StringRef EnumAttr::getUnderlyingType() const {
return def->getValueAsString("underlyingType");
}

StringRef EnumAttr::getUnderlyingToSymbolFnName() const {
return def->getValueAsString("underlyingToSymbolFnName");
}

StringRef EnumAttr::getStringToSymbolFnName() const {
return def->getValueAsString("stringToSymbolFnName");
}

StringRef EnumAttr::getSymbolToStringFnName() const {
return def->getValueAsString("symbolToStringFnName");
}

StringRef EnumAttr::getSymbolToStringFnRetType() const {
return def->getValueAsString("symbolToStringFnRetType");
}

StringRef EnumAttr::getMaxEnumValFnName() const {
return def->getValueAsString("maxEnumValFnName");
}

std::vector<EnumAttrCase> EnumAttr::getAllCases() const {
const auto *inits = def->getValueAsListInit("enumerants");

std::vector<EnumAttrCase> cases;
cases.reserve(inits->size());

for (const Init *init : *inits) {
cases.emplace_back(cast<DefInit>(init));
}

return cases;
}

bool EnumAttr::genSpecializedAttr() const {
return def->getValueAsBit("genSpecializedAttr");
}

const Record *EnumAttr::getBaseAttrClass() const {
return def->getValueAsDef("baseAttrClass");
}

StringRef EnumAttr::getSpecializedAttrClassName() const {
return def->getValueAsString("specializedAttrClassName");
}

bool EnumAttr::printBitEnumPrimaryGroups() const {
return def->getValueAsBit("printBitEnumPrimaryGroups");
}

const char * ::mlir::tblgen::inferTypeOpInterface = "InferTypeOpInterface";
1 change: 1 addition & 0 deletions mlir/lib/TableGen/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ llvm_add_library(MLIRTableGen STATIC
CodeGenHelpers.cpp
Constraint.cpp
Dialect.cpp
EnumInfo.cpp
Format.cpp
GenInfo.cpp
Interfaces.cpp
Expand Down
Loading