Skip to content

Commit fdc496a

Browse files
committed
[mlir] EnumsGen: dissociate string form of integer enum from C++ symbol name
Summary: In some cases, one may want to use different names for C++ symbol of an enumerand from its string representation. In particular, in the LLVM dialect for, e.g., Linkage, we would like to preserve the same enumerand names as LLVM API and the same textual IR form as LLVM IR, yet the two are different (CamelCase vs snake_case with additional limitations on not being a C++ keyword). Modify EnumAttrCaseInfo in OpBase.td to include both the integer value and its string representation. By default, this representation is the same as C++ symbol name. Introduce new IntStrAttrCaseBase that allows one to use different names. Exercise it for LLVM Dialect Linkage attribute. Other attributes will follow as separate changes. Differential Revision: https://reviews.llvm.org/D73362
1 parent 3bbe7a6 commit fdc496a

File tree

9 files changed

+115
-88
lines changed

9 files changed

+115
-88
lines changed

mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,4 +62,15 @@ class LLVM_Op<string mnemonic, list<OpTrait> traits = []> :
6262
class LLVM_IntrOp<string mnemonic, list<OpTrait> traits = []> :
6363
LLVM_Op<"intr."#mnemonic, traits>;
6464

65+
// Case of the LLVM enum attribute backed by I64Attr with customized string
66+
// representation that corresponds to what is visible in the textual IR form.
67+
class LLVM_EnumAttrCase<string cppSym, string irSym, int val> :
68+
I64EnumAttrCase<cppSym, val, irSym>;
69+
70+
// LLVM enum attribute backed by I64Attr with string representation
71+
// corresponding to what is visible in the textual IR form.
72+
class LLVM_EnumAttr<string name, string description,
73+
list<LLVM_EnumAttrCase> cases> :
74+
I64EnumAttr<name, description, cases>;
75+
6576
#endif // LLVMIR_OP_BASE

mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -494,18 +494,21 @@ def LLVM_UnreachableOp : LLVM_TerminatorOp<"unreachable", []> {
494494
// https://llvm.org/docs/LangRef.html#linkage-types. The names are equivalent to
495495
// visible names in the IR rather than to enum values names in llvm::GlobalValue
496496
// since the latter is easier to change.
497-
def LinkagePrivate : I64EnumAttrCase<"Private", 0>;
498-
def LinkageInternal : I64EnumAttrCase<"Internal", 1>;
499-
def LinkageAvailableExternally : I64EnumAttrCase<"AvailableExternally", 2>;
500-
def LinkageLinkonce : I64EnumAttrCase<"Linkonce", 3>;
501-
def LinkageWeak : I64EnumAttrCase<"Weak", 4>;
502-
def LinkageCommon : I64EnumAttrCase<"Common", 5>;
503-
def LinkageAppending : I64EnumAttrCase<"Appending", 6>;
504-
def LinkageExternWeak : I64EnumAttrCase<"ExternWeak", 7>;
505-
def LinkageLinkonceODR : I64EnumAttrCase<"LinkonceODR", 8>;
506-
def LinkageWeakODR : I64EnumAttrCase<"WeakODR", 9>;
507-
def LinkageExternal : I64EnumAttrCase<"External", 10>;
508-
def Linkage : I64EnumAttr<
497+
def LinkagePrivate : LLVM_EnumAttrCase<"Private", "private", 0>;
498+
def LinkageInternal : LLVM_EnumAttrCase<"Internal", "internal", 1>;
499+
def LinkageAvailableExternally : LLVM_EnumAttrCase<"AvailableExternally",
500+
"available_externally", 2>;
501+
def LinkageLinkonce : LLVM_EnumAttrCase<"Linkonce", "linkonce", 3>;
502+
def LinkageWeak : LLVM_EnumAttrCase<"Weak", "weak", 4>;
503+
def LinkageCommon : LLVM_EnumAttrCase<"Common", "common", 5>;
504+
def LinkageAppending : LLVM_EnumAttrCase<"Appending", "appending", 6>;
505+
def LinkageExternWeak : LLVM_EnumAttrCase<"ExternWeak",
506+
"extern_weak", 7>;
507+
def LinkageLinkonceODR : LLVM_EnumAttrCase<"LinkonceODR",
508+
"linkonce_odr", 8>;
509+
def LinkageWeakODR : LLVM_EnumAttrCase<"WeakODR", "weak_odr", 9>;
510+
def LinkageExternal : LLVM_EnumAttrCase<"External", "external", 10>;
511+
def Linkage : LLVM_EnumAttr<
509512
"Linkage",
510513
"LLVM linkage types",
511514
[LinkagePrivate, LinkageInternal, LinkageAvailableExternally,

mlir/include/mlir/IR/OpBase.td

Lines changed: 21 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -808,39 +808,47 @@ def UnitAttr : Attr<CPred<"$_self.isa<UnitAttr>()">, "unit attribute"> {
808808
// Enum attribute kinds
809809

810810
// Additional information for an enum attribute case.
811-
class EnumAttrCaseInfo<string sym, int val> {
812-
// The C++ enumerant symbol
811+
class EnumAttrCaseInfo<string sym, int intVal, string strVal> {
812+
// The C++ enumerant symbol.
813813
string symbol = sym;
814814

815-
// The C++ enumerant value
815+
// The C++ enumerant value.
816816
// If less than zero, there will be no explicit discriminator values assigned
817817
// to enumerators in the generated enum class.
818-
int value = val;
818+
int value = intVal;
819+
820+
// The string representation of the enumerant. May be the same as symbol.
821+
string str = strVal;
819822
}
820823

821824
// An enum attribute case stored with StringAttr.
822825
class StrEnumAttrCase<string sym, int val = -1> :
823-
EnumAttrCaseInfo<sym, val>,
826+
EnumAttrCaseInfo<sym, val, sym>,
824827
StringBasedAttr<
825828
CPred<"$_self.cast<StringAttr>().getValue() == \"" # sym # "\"">,
826829
"case " # sym>;
827830

828-
// An enum attribute case stored with IntegerAttr.
829-
class IntEnumAttrCaseBase<I intType, string sym, int val> :
830-
EnumAttrCaseInfo<sym, val>,
831-
IntegerAttrBase<intType, "case " # sym> {
831+
// An enum attribute case stored with IntegerAttr, which has an integer value,
832+
// its representation as a string and a C++ symbol name which may be different.
833+
class IntEnumAttrCaseBase<I intType, string sym, string strVal, int intVal> :
834+
EnumAttrCaseInfo<sym, intVal, strVal>,
835+
IntegerAttrBase<intType, "case " # strVal> {
832836
let predicate =
833-
CPred<"$_self.cast<IntegerAttr>().getInt() == " # val>;
837+
CPred<"$_self.cast<IntegerAttr>().getInt() == " # intVal>;
834838
}
835839

836-
class I32EnumAttrCase<string sym, int val> : IntEnumAttrCaseBase<I32, sym, val>;
837-
class I64EnumAttrCase<string sym, int val> : IntEnumAttrCaseBase<I64, sym, val>;
840+
// Cases of integer enum attributes with a specific type. By default, the string
841+
// representation is the same as the C++ symbol name.
842+
class I32EnumAttrCase<string sym, int val, string str = sym>
843+
: IntEnumAttrCaseBase<I32, sym, str, val>;
844+
class I64EnumAttrCase<string sym, int val, string str = sym>
845+
: IntEnumAttrCaseBase<I64, sym, str, val>;
838846

839847
// A bit enum case stored with 32-bit IntegerAttr. `val` here is *not* the
840848
// ordinal number of the bit that is set. It is the 32-bit integer with only
841849
// one bit set.
842850
class BitEnumAttrCase<string sym, int val> :
843-
EnumAttrCaseInfo<sym, val>,
851+
EnumAttrCaseInfo<sym, val, sym>,
844852
IntegerAttrBase<I32, "case " # sym> {
845853
let predicate = CPred<
846854
"$_self.cast<IntegerAttr>().getValue().getZExtValue() & " # val # "u">;

mlir/include/mlir/TableGen/Attribute.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,9 @@ class EnumAttrCase : public Attribute {
134134
// Returns the symbol of this enum attribute case.
135135
StringRef getSymbol() const;
136136

137+
// Returns the textual representation of this enum attribute case.
138+
StringRef getStr() const;
139+
137140
// Returns the value of this enum attribute case.
138141
int64_t getValue() const;
139142

mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp

Lines changed: 33 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -1079,44 +1079,8 @@ void GlobalOp::build(Builder *builder, OperationState &result, LLVMType type,
10791079
result.addRegion();
10801080
}
10811081

1082-
// Returns the textual representation of the given linkage.
1083-
static StringRef linkageToStr(LLVM::Linkage linkage) {
1084-
switch (linkage) {
1085-
case LLVM::Linkage::Private:
1086-
return "private";
1087-
case LLVM::Linkage::Internal:
1088-
return "internal";
1089-
case LLVM::Linkage::AvailableExternally:
1090-
return "available_externally";
1091-
case LLVM::Linkage::Linkonce:
1092-
return "linkonce";
1093-
case LLVM::Linkage::Weak:
1094-
return "weak";
1095-
case LLVM::Linkage::Common:
1096-
return "common";
1097-
case LLVM::Linkage::Appending:
1098-
return "appending";
1099-
case LLVM::Linkage::ExternWeak:
1100-
return "extern_weak";
1101-
case LLVM::Linkage::LinkonceODR:
1102-
return "linkonce_odr";
1103-
case LLVM::Linkage::WeakODR:
1104-
return "weak_odr";
1105-
case LLVM::Linkage::External:
1106-
return "external";
1107-
}
1108-
llvm_unreachable("unknown linkage type");
1109-
}
1110-
1111-
// Prints the keyword for the linkage type using the printer.
1112-
static void printLinkage(OpAsmPrinter &p, LLVM::Linkage linkage) {
1113-
p << linkageToStr(linkage);
1114-
}
1115-
11161082
static void printGlobalOp(OpAsmPrinter &p, GlobalOp op) {
1117-
p << op.getOperationName() << ' ';
1118-
printLinkage(p, op.linkage());
1119-
p << ' ';
1083+
p << op.getOperationName() << ' ' << stringifyLinkage(op.linkage()) << ' ';
11201084
if (op.constant())
11211085
p << "constant ";
11221086
p.printSymbolName(op.sym_name());
@@ -1150,22 +1114,30 @@ static int parseOptionalKeywordAlternative(OpAsmParser &parser,
11501114
return -1;
11511115
}
11521116

1153-
// Parses one of the linkage keywords and, if succeeded, appends the "linkage"
1154-
// integer attribute with the corresponding value to `result`.
1155-
//
1156-
// linkage ::= `private` | `internal` | `available_externally` | `linkonce`
1157-
// | `weak` | `common` | `appending` | `extern_weak`
1158-
// | `linkonce_odr` | `weak_odr` | `external
1159-
static ParseResult parseOptionalLinkageKeyword(OpAsmParser &parser,
1160-
OperationState &result) {
1161-
int index = parseOptionalKeywordAlternative(
1162-
parser, {"private", "internal", "available_externally", "linkonce",
1163-
"weak", "common", "appending", "extern_weak", "linkonce_odr",
1164-
"weak_odr", "external"});
1117+
namespace {
1118+
template <typename Ty> struct EnumTraits {};
1119+
1120+
#define REGISTER_ENUM_TYPE(Ty) \
1121+
template <> struct EnumTraits<Ty> { \
1122+
static StringRef stringify(Ty value) { return stringify##Ty(value); } \
1123+
static unsigned getMaxEnumVal() { return getMaxEnumValFor##Ty(); } \
1124+
}
1125+
1126+
REGISTER_ENUM_TYPE(Linkage);
1127+
} // end namespace
1128+
1129+
template <typename EnumTy>
1130+
static ParseResult parseOptionalLLVMKeyword(OpAsmParser &parser,
1131+
OperationState &result,
1132+
StringRef name) {
1133+
SmallVector<StringRef, 10> names;
1134+
for (unsigned i = 0, e = getMaxEnumValForLinkage(); i <= e; ++i)
1135+
names.push_back(EnumTraits<EnumTy>::stringify(static_cast<EnumTy>(i)));
1136+
1137+
int index = parseOptionalKeywordAlternative(parser, names);
11651138
if (index == -1)
11661139
return failure();
1167-
result.addAttribute(getLinkageAttrName(),
1168-
parser.getBuilder().getI64IntegerAttr(index));
1140+
result.addAttribute(name, parser.getBuilder().getI64IntegerAttr(index));
11691141
return success();
11701142
}
11711143

@@ -1175,7 +1147,8 @@ static ParseResult parseOptionalLinkageKeyword(OpAsmParser &parser,
11751147
// The type can be omitted for string attributes, in which case it will be
11761148
// inferred from the value of the string as [strlen(value) x i8].
11771149
static ParseResult parseGlobalOp(OpAsmParser &parser, OperationState &result) {
1178-
if (failed(parseOptionalLinkageKeyword(parser, result)))
1150+
if (failed(parseOptionalLLVMKeyword<Linkage>(parser, result,
1151+
getLinkageAttrName())))
11791152
return parser.emitError(parser.getCurrentLocation(), "expected linkage");
11801153

11811154
if (succeeded(parser.parseOptionalKeyword("constant")))
@@ -1398,7 +1371,8 @@ static Type buildLLVMFunctionType(OpAsmParser &parser, llvm::SMLoc loc,
13981371
static ParseResult parseLLVMFuncOp(OpAsmParser &parser,
13991372
OperationState &result) {
14001373
// Default to external linkage if no keyword is provided.
1401-
if (failed(parseOptionalLinkageKeyword(parser, result)))
1374+
if (failed(parseOptionalLLVMKeyword<Linkage>(parser, result,
1375+
getLinkageAttrName())))
14021376
result.addAttribute(getLinkageAttrName(),
14031377
parser.getBuilder().getI64IntegerAttr(
14041378
static_cast<int64_t>(LLVM::Linkage::External)));
@@ -1441,10 +1415,8 @@ static ParseResult parseLLVMFuncOp(OpAsmParser &parser,
14411415
// the external linkage since it is the default value.
14421416
static void printLLVMFuncOp(OpAsmPrinter &p, LLVMFuncOp op) {
14431417
p << op.getOperationName() << ' ';
1444-
if (op.linkage() != LLVM::Linkage::External) {
1445-
printLinkage(p, op.linkage());
1446-
p << ' ';
1447-
}
1418+
if (op.linkage() != LLVM::Linkage::External)
1419+
p << stringifyLinkage(op.linkage()) << ' ';
14481420
p.printSymbolName(op.getName());
14491421

14501422
LLVMType fnType = op.getType();
@@ -1510,16 +1482,16 @@ unsigned LLVMFuncOp::getNumFuncResults() {
15101482
static LogicalResult verify(LLVMFuncOp op) {
15111483
if (op.linkage() == LLVM::Linkage::Common)
15121484
return op.emitOpError()
1513-
<< "functions cannot have '" << linkageToStr(LLVM::Linkage::Common)
1514-
<< "' linkage";
1485+
<< "functions cannot have '"
1486+
<< stringifyLinkage(LLVM::Linkage::Common) << "' linkage";
15151487

15161488
if (op.isExternal()) {
15171489
if (op.linkage() != LLVM::Linkage::External &&
15181490
op.linkage() != LLVM::Linkage::ExternWeak)
15191491
return op.emitOpError()
15201492
<< "external functions must have '"
1521-
<< linkageToStr(LLVM::Linkage::External) << "' or '"
1522-
<< linkageToStr(LLVM::Linkage::ExternWeak) << "' linkage";
1493+
<< stringifyLinkage(LLVM::Linkage::External) << "' or '"
1494+
<< stringifyLinkage(LLVM::Linkage::ExternWeak) << "' linkage";
15231495
return success();
15241496
}
15251497

mlir/lib/TableGen/Attribute.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,10 @@ StringRef tblgen::EnumAttrCase::getSymbol() const {
154154
return def->getValueAsString("symbol");
155155
}
156156

157+
StringRef tblgen::EnumAttrCase::getStr() const {
158+
return def->getValueAsString("str");
159+
}
160+
157161
int64_t tblgen::EnumAttrCase::getValue() const {
158162
return def->getValueAsInt("value");
159163
}

mlir/tools/mlir-tblgen/EnumsGen.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -165,8 +165,9 @@ static void emitSymToStrFnForIntEnum(const Record &enumDef, raw_ostream &os) {
165165
os << " switch (val) {\n";
166166
for (const auto &enumerant : enumerants) {
167167
auto symbol = enumerant.getSymbol();
168+
auto str = enumerant.getStr();
168169
os << formatv(" case {0}::{1}: return \"{2}\";\n", enumName,
169-
makeIdentifier(symbol), symbol);
170+
makeIdentifier(symbol), str);
170171
}
171172
os << " }\n";
172173
os << " return \"\";\n";
@@ -219,7 +220,8 @@ static void emitStrToSymFnForIntEnum(const Record &enumDef, raw_ostream &os) {
219220
enumName);
220221
for (const auto &enumerant : enumerants) {
221222
auto symbol = enumerant.getSymbol();
222-
os << formatv(" .Case(\"{1}\", {0}::{2})\n", enumName, symbol,
223+
auto str = enumerant.getStr();
224+
os << formatv(" .Case(\"{1}\", {0}::{2})\n", enumName, str,
223225
makeIdentifier(symbol));
224226
}
225227
os << " .Default(llvm::None);\n";

mlir/unittests/TableGen/EnumsGenTest.cpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,3 +94,21 @@ TEST(EnumsGenTest, GeneratedOperator) {
9494
EXPECT_FALSE(bitEnumContains(BitEnumWithNone::Bit1 & BitEnumWithNone::Bit3,
9595
BitEnumWithNone::Bit1));
9696
}
97+
98+
TEST(EnumsGenTest, GeneratedSymbolToCustomStringFn) {
99+
EXPECT_EQ(stringifyPrettyIntEnum(PrettyIntEnum::Case1), "case_one");
100+
EXPECT_EQ(stringifyPrettyIntEnum(PrettyIntEnum::Case2), "case_two");
101+
}
102+
103+
TEST(EnumsGenTest, GeneratedCustomStringToSymbolFn) {
104+
auto one = symbolizePrettyIntEnum("case_one");
105+
EXPECT_TRUE(one);
106+
EXPECT_EQ(*one, PrettyIntEnum::Case1);
107+
108+
auto two = symbolizePrettyIntEnum("case_two");
109+
EXPECT_TRUE(two);
110+
EXPECT_EQ(*two, PrettyIntEnum::Case2);
111+
112+
auto none = symbolizePrettyIntEnum("Case1");
113+
EXPECT_FALSE(none);
114+
}

mlir/unittests/TableGen/enums.td

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,3 +31,9 @@ def BitEnumWithNone : BitEnumAttr<"BitEnumWithNone", "A test enum",
3131

3232
def BitEnumWithoutNone : BitEnumAttr<"BitEnumWithoutNone", "A test enum",
3333
[Bit1, Bit3]>;
34+
35+
def PrettyIntEnumCase1: I32EnumAttrCase<"Case1", 1, "case_one">;
36+
def PrettyIntEnumCase2: I32EnumAttrCase<"Case2", 2, "case_two">;
37+
38+
def PrettyIntEnum: I32EnumAttr<"PrettyIntEnum", "A test enum",
39+
[PrettyIntEnumCase1, PrettyIntEnumCase2]>;

0 commit comments

Comments
 (0)