Skip to content

Commit c8b6e56

Browse files
committed
[mlir] Decouple enum generation from attributes, adding EnumInfo and EnumCase
This commit pulls apart the inherent attribute dependence of classes like EnumAttrInfo and EnumAttrCase, factoring them out into simpler EnumCase and EnumInfo variants. This allows specifying the cases of an enum without needing to make the cases, or the EnumInfo itself, a subclass of SignlessIntegerAttrBase. The existing classes are retained as subclasses of the new ones, both for backwards compatibility and to allow attribute-specific information. In addition, the new BitEnum class changes its default printer/parser behavior: cases when multiple keywords appear, like having both nuw and nsw in overflow flags, will no longer be quoted by the operator<<, and the FieldParser instance will now expect multiple keywords. All instances of BitEnumAttr retain the old behavior.
1 parent c9055e9 commit c8b6e56

File tree

21 files changed

+884
-516
lines changed

21 files changed

+884
-516
lines changed

mlir/include/mlir/IR/EnumAttr.td

Lines changed: 203 additions & 68 deletions
Large diffs are not rendered by default.

mlir/include/mlir/TableGen/Attribute.h

Lines changed: 0 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616

1717
#include "mlir/Support/LLVM.h"
1818
#include "mlir/TableGen/Constraint.h"
19-
#include "llvm/ADT/StringRef.h"
2019

2120
namespace llvm {
2221
class DefInit;
@@ -136,79 +135,6 @@ class ConstantAttr {
136135
const llvm::Record *def;
137136
};
138137

139-
// Wrapper class providing helper methods for accessing enum attribute cases
140-
// defined in TableGen. This is used for enum attribute case backed by both
141-
// StringAttr and IntegerAttr.
142-
class EnumAttrCase : public Attribute {
143-
public:
144-
explicit EnumAttrCase(const llvm::Record *record);
145-
explicit EnumAttrCase(const llvm::DefInit *init);
146-
147-
// Returns the symbol of this enum attribute case.
148-
StringRef getSymbol() const;
149-
150-
// Returns the textual representation of this enum attribute case.
151-
StringRef getStr() const;
152-
153-
// Returns the value of this enum attribute case.
154-
int64_t getValue() const;
155-
156-
// Returns the TableGen definition this EnumAttrCase was constructed from.
157-
const llvm::Record &getDef() const;
158-
};
159-
160-
// Wrapper class providing helper methods for accessing enum attributes defined
161-
// in TableGen.This is used for enum attribute case backed by both StringAttr
162-
// and IntegerAttr.
163-
class EnumAttr : public Attribute {
164-
public:
165-
explicit EnumAttr(const llvm::Record *record);
166-
explicit EnumAttr(const llvm::Record &record);
167-
explicit EnumAttr(const llvm::DefInit *init);
168-
169-
static bool classof(const Attribute *attr);
170-
171-
// Returns true if this is a bit enum attribute.
172-
bool isBitEnum() const;
173-
174-
// Returns the enum class name.
175-
StringRef getEnumClassName() const;
176-
177-
// Returns the C++ namespaces this enum class should be placed in.
178-
StringRef getCppNamespace() const;
179-
180-
// Returns the underlying type.
181-
StringRef getUnderlyingType() const;
182-
183-
// Returns the name of the utility function that converts a value of the
184-
// underlying type to the corresponding symbol.
185-
StringRef getUnderlyingToSymbolFnName() const;
186-
187-
// Returns the name of the utility function that converts a string to the
188-
// corresponding symbol.
189-
StringRef getStringToSymbolFnName() const;
190-
191-
// Returns the name of the utility function that converts a symbol to the
192-
// corresponding string.
193-
StringRef getSymbolToStringFnName() const;
194-
195-
// Returns the return type of the utility function that converts a symbol to
196-
// the corresponding string.
197-
StringRef getSymbolToStringFnRetType() const;
198-
199-
// Returns the name of the utilit function that returns the max enum value
200-
// used within the enum class.
201-
StringRef getMaxEnumValFnName() const;
202-
203-
// Returns all allowed cases for this enum attribute.
204-
std::vector<EnumAttrCase> getAllCases() const;
205-
206-
bool genSpecializedAttr() const;
207-
const llvm::Record *getBaseAttrClass() const;
208-
StringRef getSpecializedAttrClassName() const;
209-
bool printBitEnumPrimaryGroups() const;
210-
};
211-
212138
// Name of infer type op interface.
213139
extern const char *inferTypeOpInterface;
214140

mlir/include/mlir/TableGen/EnumInfo.h

Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,139 @@
1+
//===- EnumInfo.h - EnumInfo wrapper class --------------------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// EnumInfo wrapper to simplify using a TableGen Record defining an Enum
10+
// via EnumInfo and its `EnumCase`s.
11+
//
12+
//===----------------------------------------------------------------------===//
13+
14+
#ifndef MLIR_TABLEGEN_ENUMINFO_H_
15+
#define MLIR_TABLEGEN_ENUMINFO_H_
16+
17+
#include "mlir/Support/LLVM.h"
18+
#include "mlir/TableGen/Attribute.h"
19+
#include "llvm/ADT/StringRef.h"
20+
21+
namespace llvm {
22+
class DefInit;
23+
class Record;
24+
} // namespace llvm
25+
26+
namespace mlir {
27+
namespace tblgen {
28+
29+
// Wrapper class providing around enum cases defined in TableGen.
30+
class EnumCase {
31+
public:
32+
explicit EnumCase(const llvm::Record *record);
33+
explicit EnumCase(const llvm::DefInit *init);
34+
35+
// Returns the symbol of this enum attribute case.
36+
StringRef getSymbol() const;
37+
38+
// Returns the textual representation of this enum attribute case.
39+
StringRef getStr() const;
40+
41+
// Returns the value of this enum attribute case.
42+
int64_t getValue() const;
43+
44+
// Returns the TableGen definition this EnumAttrCase was constructed from.
45+
const llvm::Record &getDef() const;
46+
47+
protected:
48+
// The TableGen definition of this constraint.
49+
const llvm::Record *def;
50+
};
51+
52+
// Wrapper class providing helper methods for accessing enums defined
53+
// in TableGen using EnumInfo. Some methods are only applicable when
54+
// the enum is also an attribute, or only when it is a bit enum.
55+
class EnumInfo {
56+
public:
57+
explicit EnumInfo(const llvm::Record *record);
58+
explicit EnumInfo(const llvm::Record &record);
59+
explicit EnumInfo(const llvm::DefInit *init);
60+
61+
// Returns true if the given EnumInfo is a subclass of the named TableGen
62+
// class.
63+
bool isSubClassOf(StringRef className) const;
64+
65+
// Returns true if this enum is an EnumAttrInfo, thus making it define an
66+
// attribute.
67+
bool isEnumAttr() const;
68+
69+
// Create the `Attribute` wrapper around this EnumInfo if it is defining an
70+
// attribute.
71+
std::optional<Attribute> asEnumAttr() const;
72+
73+
// Returns true if this is a bit enum.
74+
bool isBitEnum() const;
75+
76+
// Returns the enum class name.
77+
StringRef getEnumClassName() const;
78+
79+
// Returns the C++ namespaces this enum class should be placed in.
80+
StringRef getCppNamespace() const;
81+
82+
// Returns the summary of the enum.
83+
StringRef getSummary() const;
84+
85+
// Returns the description of the enum.
86+
StringRef getDescription() const;
87+
88+
// Returns the bitwidth of the enum.
89+
int64_t getBitwidth() const;
90+
91+
// Returns the underlying type.
92+
StringRef getUnderlyingType() const;
93+
94+
// Returns the name of the utility function that converts a value of the
95+
// underlying type to the corresponding symbol.
96+
StringRef getUnderlyingToSymbolFnName() const;
97+
98+
// Returns the name of the utility function that converts a string to the
99+
// corresponding symbol.
100+
StringRef getStringToSymbolFnName() const;
101+
102+
// Returns the name of the utility function that converts a symbol to the
103+
// corresponding string.
104+
StringRef getSymbolToStringFnName() const;
105+
106+
// Returns the return type of the utility function that converts a symbol to
107+
// the corresponding string.
108+
StringRef getSymbolToStringFnRetType() const;
109+
110+
// Returns the name of the utilit function that returns the max enum value
111+
// used within the enum class.
112+
StringRef getMaxEnumValFnName() const;
113+
114+
// Returns all allowed cases for this enum attribute.
115+
std::vector<EnumCase> getAllCases() const;
116+
117+
// Only applicable for enum attributes.
118+
119+
bool genSpecializedAttr() const;
120+
const llvm::Record *getBaseAttrClass() const;
121+
StringRef getSpecializedAttrClassName() const;
122+
123+
// Only applicable for bit enums.
124+
125+
bool printBitEnumPrimaryGroups() const;
126+
bool printBitEnumQuoted() const;
127+
128+
// Returns the TableGen definition this EnumAttrCase was constructed from.
129+
const llvm::Record &getDef() const;
130+
131+
protected:
132+
// The TableGen definition of this constraint.
133+
const llvm::Record *def;
134+
};
135+
136+
} // namespace tblgen
137+
} // namespace mlir
138+
139+
#endif

mlir/include/mlir/TableGen/Pattern.h

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
#include "mlir/Support/LLVM.h"
1818
#include "mlir/TableGen/Argument.h"
19+
#include "mlir/TableGen/EnumInfo.h"
1920
#include "mlir/TableGen/Operator.h"
2021
#include "llvm/ADT/DenseMap.h"
2122
#include "llvm/ADT/Hashing.h"
@@ -78,8 +79,8 @@ class DagLeaf {
7879
// Returns true if this DAG leaf is specifying a constant attribute.
7980
bool isConstantAttr() const;
8081

81-
// Returns true if this DAG leaf is specifying an enum attribute case.
82-
bool isEnumAttrCase() const;
82+
// Returns true if this DAG leaf is specifying an enum case.
83+
bool isEnumCase() const;
8384

8485
// Returns true if this DAG leaf is specifying a string attribute.
8586
bool isStringAttr() const;
@@ -90,9 +91,9 @@ class DagLeaf {
9091
// Returns this DAG leaf as an constant attribute. Asserts if fails.
9192
ConstantAttr getAsConstantAttr() const;
9293

93-
// Returns this DAG leaf as an enum attribute case.
94-
// Precondition: isEnumAttrCase()
95-
EnumAttrCase getAsEnumAttrCase() const;
94+
// Returns this DAG leaf as an enum case.
95+
// Precondition: isEnumCase()
96+
EnumCase getAsEnumCase() const;
9697

9798
// Returns the matching condition template inside this DAG leaf. Assumes the
9899
// leaf is an operand/attribute matcher and asserts otherwise.

mlir/lib/TableGen/Attribute.cpp

Lines changed: 0 additions & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -146,98 +146,4 @@ StringRef ConstantAttr::getConstantValue() const {
146146
return def->getValueAsString("value");
147147
}
148148

149-
EnumAttrCase::EnumAttrCase(const Record *record) : Attribute(record) {
150-
assert(isSubClassOf("EnumAttrCaseInfo") &&
151-
"must be subclass of TableGen 'EnumAttrInfo' class");
152-
}
153-
154-
EnumAttrCase::EnumAttrCase(const DefInit *init)
155-
: EnumAttrCase(init->getDef()) {}
156-
157-
StringRef EnumAttrCase::getSymbol() const {
158-
return def->getValueAsString("symbol");
159-
}
160-
161-
StringRef EnumAttrCase::getStr() const { return def->getValueAsString("str"); }
162-
163-
int64_t EnumAttrCase::getValue() const { return def->getValueAsInt("value"); }
164-
165-
const Record &EnumAttrCase::getDef() const { return *def; }
166-
167-
EnumAttr::EnumAttr(const Record *record) : Attribute(record) {
168-
assert(isSubClassOf("EnumAttrInfo") &&
169-
"must be subclass of TableGen 'EnumAttr' class");
170-
}
171-
172-
EnumAttr::EnumAttr(const Record &record) : Attribute(&record) {}
173-
174-
EnumAttr::EnumAttr(const DefInit *init) : EnumAttr(init->getDef()) {}
175-
176-
bool EnumAttr::classof(const Attribute *attr) {
177-
return attr->isSubClassOf("EnumAttrInfo");
178-
}
179-
180-
bool EnumAttr::isBitEnum() const { return isSubClassOf("BitEnumAttr"); }
181-
182-
StringRef EnumAttr::getEnumClassName() const {
183-
return def->getValueAsString("className");
184-
}
185-
186-
StringRef EnumAttr::getCppNamespace() const {
187-
return def->getValueAsString("cppNamespace");
188-
}
189-
190-
StringRef EnumAttr::getUnderlyingType() const {
191-
return def->getValueAsString("underlyingType");
192-
}
193-
194-
StringRef EnumAttr::getUnderlyingToSymbolFnName() const {
195-
return def->getValueAsString("underlyingToSymbolFnName");
196-
}
197-
198-
StringRef EnumAttr::getStringToSymbolFnName() const {
199-
return def->getValueAsString("stringToSymbolFnName");
200-
}
201-
202-
StringRef EnumAttr::getSymbolToStringFnName() const {
203-
return def->getValueAsString("symbolToStringFnName");
204-
}
205-
206-
StringRef EnumAttr::getSymbolToStringFnRetType() const {
207-
return def->getValueAsString("symbolToStringFnRetType");
208-
}
209-
210-
StringRef EnumAttr::getMaxEnumValFnName() const {
211-
return def->getValueAsString("maxEnumValFnName");
212-
}
213-
214-
std::vector<EnumAttrCase> EnumAttr::getAllCases() const {
215-
const auto *inits = def->getValueAsListInit("enumerants");
216-
217-
std::vector<EnumAttrCase> cases;
218-
cases.reserve(inits->size());
219-
220-
for (const Init *init : *inits) {
221-
cases.emplace_back(cast<DefInit>(init));
222-
}
223-
224-
return cases;
225-
}
226-
227-
bool EnumAttr::genSpecializedAttr() const {
228-
return def->getValueAsBit("genSpecializedAttr");
229-
}
230-
231-
const Record *EnumAttr::getBaseAttrClass() const {
232-
return def->getValueAsDef("baseAttrClass");
233-
}
234-
235-
StringRef EnumAttr::getSpecializedAttrClassName() const {
236-
return def->getValueAsString("specializedAttrClassName");
237-
}
238-
239-
bool EnumAttr::printBitEnumPrimaryGroups() const {
240-
return def->getValueAsBit("printBitEnumPrimaryGroups");
241-
}
242-
243149
const char * ::mlir::tblgen::inferTypeOpInterface = "InferTypeOpInterface";

mlir/lib/TableGen/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ llvm_add_library(MLIRTableGen STATIC
2020
CodeGenHelpers.cpp
2121
Constraint.cpp
2222
Dialect.cpp
23+
EnumInfo.cpp
2324
Format.cpp
2425
GenInfo.cpp
2526
Interfaces.cpp

0 commit comments

Comments
 (0)