Skip to content

Commit 17d981b

Browse files
committed
[mlir] Improve EnumProp, making it take an EnumInfo
This commit improves the `EnumProp` class, causing it to wrap around an `EnumInfo` just like` EnumAttr` does. This EnumProp also has logic for converting to/from an integer attribute and for being read and written as bitcode. The following variants of `EnumProp` are provided: - `EnumPropWithAttrForm` - an EnumProp that can be constructed from (and will be converted to, if `storeInCustomAttribute` is true) a custom attribute, like an `EnumAttr`, instead of a plain integer. This is meant for backwards compatibility with code that uses enum attributes. `NamedEnumProp` adds a "`mnemonic` `<` $enum `>`" syntax around the enum, replicating a common pattern seen in MLIR printers and allowing for reduced ambiguity. `NamedEnumPropWithAttrForm` combines both of these extensions. (Sadly, bitcode auto-upgrade is hampered by the lack of the ability to optionally parse an attribute.)
1 parent c8b6e56 commit 17d981b

File tree

8 files changed

+313
-92
lines changed

8 files changed

+313
-92
lines changed

mlir/include/mlir/Dialect/LLVMIR/LLVMEnums.td

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -485,17 +485,16 @@ def DISubprogramFlags : I32BitEnumAttr<
485485
// IntegerOverflowFlags
486486
//===----------------------------------------------------------------------===//
487487

488-
def IOFnone : I32BitEnumAttrCaseNone<"none">;
489-
def IOFnsw : I32BitEnumAttrCaseBit<"nsw", 0>;
490-
def IOFnuw : I32BitEnumAttrCaseBit<"nuw", 1>;
488+
def IOFnone : I32BitEnumCaseNone<"none">;
489+
def IOFnsw : I32BitEnumCaseBit<"nsw", 0>;
490+
def IOFnuw : I32BitEnumCaseBit<"nuw", 1>;
491491

492-
def IntegerOverflowFlags : I32BitEnumAttr<
492+
def IntegerOverflowFlags : I32BitEnum<
493493
"IntegerOverflowFlags",
494494
"LLVM integer overflow flags",
495495
[IOFnone, IOFnsw, IOFnuw]> {
496496
let separator = ", ";
497497
let cppNamespace = "::mlir::LLVM";
498-
let genSpecializedAttr = 0;
499498
let printBitEnumPrimaryGroups = 1;
500499
}
501500

@@ -504,6 +503,11 @@ def LLVM_IntegerOverflowFlagsAttr :
504503
let assemblyFormat = "`<` $value `>`";
505504
}
506505

506+
def LLVM_IntegerOverflowFlagsProp :
507+
NamedEnumPropWithAttrForm<IntegerOverflowFlags, "overflow", LLVM_IntegerOverflowFlagsAttr> {
508+
let defaultValue = enum.cppType # "::" # "none";
509+
}
510+
507511
//===----------------------------------------------------------------------===//
508512
// FastmathFlags
509513
//===----------------------------------------------------------------------===//

mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ class LLVM_IntArithmeticOpWithOverflowFlag<string mnemonic, string instName,
6060
list<Trait> traits = []> :
6161
LLVM_ArithmeticOpBase<AnySignlessInteger, mnemonic, instName,
6262
!listconcat([DeclareOpInterfaceMethods<IntegerOverflowFlagsInterface>], traits)> {
63-
dag iofArg = (ins EnumProp<"IntegerOverflowFlags", "", "IntegerOverflowFlags::none">:$overflowFlags);
63+
dag iofArg = (ins LLVM_IntegerOverflowFlagsProp:$overflowFlags);
6464
let arguments = !con(commonArgs, iofArg);
6565

6666
string mlirBuilder = [{
@@ -69,7 +69,7 @@ class LLVM_IntArithmeticOpWithOverflowFlag<string mnemonic, string instName,
6969
$res = op;
7070
}];
7171
let assemblyFormat = [{
72-
$lhs `,` $rhs `` custom<OverflowFlags>($overflowFlags) attr-dict `:` type($res)
72+
$lhs `,` $rhs ($overflowFlags^)? attr-dict `:` type($res)
7373
}];
7474
string llvmBuilder =
7575
"$res = builder.Create" # instName #
@@ -563,10 +563,10 @@ class LLVM_CastOpWithOverflowFlag<string mnemonic, string instName, Type type,
563563
Type resultType, list<Trait> traits = []> :
564564
LLVM_Op<mnemonic, !listconcat([Pure], [DeclareOpInterfaceMethods<IntegerOverflowFlagsInterface>], traits)>,
565565
LLVM_Builder<"$res = builder.Create" # instName # "($arg, $_resultType, /*Name=*/\"\", op.hasNoUnsignedWrap(), op.hasNoSignedWrap());"> {
566-
let arguments = (ins type:$arg, EnumProp<"IntegerOverflowFlags", "", "IntegerOverflowFlags::none">:$overflowFlags);
566+
let arguments = (ins type:$arg, LLVM_IntegerOverflowFlagsProp:$overflowFlags);
567567
let results = (outs resultType:$res);
568568
let builders = [LLVM_OneResultOpBuilder];
569-
let assemblyFormat = "$arg `` custom<OverflowFlags>($overflowFlags) attr-dict `:` type($arg) `to` type($res)";
569+
let assemblyFormat = "$arg ($overflowFlags^)? attr-dict `:` type($arg) `to` type($res)";
570570
string llvmInstName = instName;
571571
string mlirBuilder = [{
572572
auto op = $_builder.create<$_qualCppClassName>(

mlir/include/mlir/IR/EnumAttr.td

Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#define ENUMATTR_TD
1111

1212
include "mlir/IR/AttrTypeBase.td"
13+
include "mlir/IR/Properties.td"
1314

1415
//===----------------------------------------------------------------------===//
1516
// Enum attribute kinds
@@ -551,6 +552,139 @@ class EnumAttr<Dialect dialect, EnumInfo enumInfo, string name = "",
551552
let assemblyFormat = "$value";
552553
}
553554

555+
// A property wrapping by a C++ enum. This class will automatically create bytecode
556+
// serialization logic for the given enum, as well as arranging for parser and
557+
// printer calls.
558+
class EnumProp<EnumInfo enumInfo> : Property<enumInfo.cppType, enumInfo.summary> {
559+
EnumInfo enum = enumInfo;
560+
561+
let description = enum.description;
562+
let predicate = !if(
563+
!isa<BitEnumBase>(enum),
564+
CPred<"(static_cast<" # enum.underlyingType # ">($_self) & ~" # !cast<BitEnumBase>(enum).validBits # ") == 0">,
565+
Or<!foreach(case, enum.enumerants, CPred<"$_self == " # enum.cppType # "::" # case.symbol>)>);
566+
567+
let convertFromAttribute = [{
568+
auto intAttr = ::mlir::dyn_cast_if_present<::mlir::IntegerAttr>($_attr);
569+
if (!intAttr) {
570+
return $_diag() << "expected IntegerAttr storage for }] #
571+
enum.cppType # [{";
572+
}
573+
$_storage = static_cast<}] # enum.cppType # [{>(intAttr.getValue().getZExtValue());
574+
return ::mlir::success();
575+
}];
576+
577+
let convertToAttribute = [{
578+
return ::mlir::IntegerAttr::get(::mlir::IntegerType::get($_ctxt, }] # enum.bitwidth
579+
# [{), static_cast<}] # enum.underlyingType #[{>($_storage));
580+
}];
581+
582+
let writeToMlirBytecode = [{
583+
$_writer.writeVarInt(static_cast<uint64_t>($_storage));
584+
}];
585+
586+
let readFromMlirBytecode = [{
587+
uint64_t rawValue;
588+
if (::mlir::failed($_reader.readVarInt(rawValue)))
589+
return ::mlir::failure();
590+
$_storage = static_cast<}] # enum.cppType # [{>(rawValue);
591+
}];
592+
593+
let optionalParser = [{
594+
auto value = ::mlir::FieldParser<std::optional<}] # enum.cppType # [{>>::parse($_parser);
595+
if (::mlir::failed(value))
596+
return ::mlir::failure();
597+
if (!(value->has_value()))
598+
return std::nullopt;
599+
$_storage = std::move(**value);
600+
}];
601+
}
602+
603+
// Enum property that can have been (or, if `storeInCustomAttribute` is true, will also
604+
// be stored as) an attribute, in addition to being stored as an integer attribute.
605+
class EnumPropWithAttrForm<EnumInfo enumInfo, Attr attributeForm>
606+
: EnumProp<enumInfo> {
607+
Attr attrForm = attributeForm;
608+
bit storeInCustomAttribute = 0;
609+
610+
let convertFromAttribute = [{
611+
auto customAttr = ::mlir::dyn_cast_if_present<}]
612+
# attrForm.storageType # [{>($_attr);
613+
if (customAttr) {
614+
$_storage = customAttr.getValue();
615+
return ::mlir::success();
616+
}
617+
auto intAttr = ::mlir::dyn_cast_if_present<::mlir::IntegerAttr>($_attr);
618+
if (!intAttr) {
619+
return $_diag() << "expected }] # attrForm.storageType
620+
# [{ or IntegerAttr storage for }] # enum.cppType # [{";
621+
}
622+
$_storage = static_cast<}] # enum.cppType # [{>(intAttr.getValue().getZExtValue());
623+
return ::mlir::success();
624+
}];
625+
626+
let convertToAttribute = !if(storeInCustomAttribute, [{
627+
return }] # attrForm.storageType # [{::get($_ctxt, $_storage);
628+
}], [{
629+
return ::mlir::IntegerAttr::get(::mlir::IntegerType::get($_ctxt, }] # enumInfo.bitwidth
630+
# [{), static_cast<}] # enum.underlyingType #[{>($_storage));
631+
}]);
632+
}
633+
634+
class _namedEnumPropFields<string cppType, string mnemonic> {
635+
code parser = [{
636+
if ($_parser.parseKeyword("}] # mnemonic # [{")
637+
|| $_parser.parseLess()) {
638+
return ::mlir::failure();
639+
}
640+
auto parseRes = ::mlir::FieldParser<}] # cppType # [{>::parse($_parser);
641+
if (::mlir::failed(parseRes) ||
642+
::mlir::failed($_parser.parseGreater())) {
643+
return ::mlir::failure();
644+
}
645+
$_storage = *parseRes;
646+
}];
647+
648+
code optionalParser = [{
649+
if ($_parser.parseOptionalKeyword("}] # mnemonic # [{")) {
650+
return std::nullopt;
651+
}
652+
if ($_parser.parseLess()) {
653+
return ::mlir::failure();
654+
}
655+
auto parseRes = ::mlir::FieldParser<}] # cppType # [{>::parse($_parser);
656+
if (::mlir::failed(parseRes) ||
657+
::mlir::failed($_parser.parseGreater())) {
658+
return ::mlir::failure();
659+
}
660+
$_storage = *parseRes;
661+
}];
662+
663+
code printer = [{
664+
$_printer << "}] # mnemonic # [{<" << $_storage << ">";
665+
}];
666+
}
667+
668+
// An EnumProp which, when printed, is surrounded by mnemonic<>.
669+
// For example, if the enum can be a, b, or c, and the mnemonic is foo,
670+
// the format of this property will be "foo<a>", "foo<b>", or "foo<c>".
671+
class NamedEnumProp<EnumInfo enumInfo, string name>
672+
: EnumProp<enumInfo> {
673+
string mnemonic = name;
674+
let parser = _namedEnumPropFields<enum.cppType, mnemonic>.parser;
675+
let optionalParser = _namedEnumPropFields<enum.cppType, mnemonic>.optionalParser;
676+
let printer = _namedEnumPropFields<enum.cppType, mnemonic>.printer;
677+
}
678+
679+
// A `NamedEnumProp` with an attribute form as in `EnumPropWithAttrForm`.
680+
class NamedEnumPropWithAttrForm<EnumInfo enumInfo, string name, Attr attributeForm>
681+
: EnumPropWithAttrForm<enumInfo, attributeForm> {
682+
string mnemonic = name;
683+
let parser = _namedEnumPropFields<enum.cppType, mnemonic>.parser;
684+
let optionalParser = _namedEnumPropFields<enum.cppType, mnemonic>.optionalParser;
685+
let printer = _namedEnumPropFields<enumInfo.cppType, mnemonic>.printer;
686+
}
687+
554688
class _symbolToValue<EnumInfo enumInfo, string case> {
555689
defvar cases =
556690
!filter(iter, enumInfo.enumerants, !eq(iter.str, case));

mlir/include/mlir/IR/Properties.td

Lines changed: 0 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -238,25 +238,6 @@ def I64Prop : IntProp<"int64_t">;
238238
def I32Property : IntProp<"int32_t">, Deprecated<"moved to shorter name I32Prop">;
239239
def I64Property : IntProp<"int64_t">, Deprecated<"moved to shorter name I64Prop">;
240240

241-
class EnumProp<string storageTypeParam, string desc = "", string default = ""> :
242-
Property<storageTypeParam, desc> {
243-
// TODO: implement predicate for enum validity.
244-
let writeToMlirBytecode = [{
245-
$_writer.writeVarInt(static_cast<uint64_t>($_storage));
246-
}];
247-
let readFromMlirBytecode = [{
248-
uint64_t val;
249-
if (failed($_reader.readVarInt(val)))
250-
return ::mlir::failure();
251-
$_storage = static_cast<}] # storageTypeParam # [{>(val);
252-
}];
253-
let defaultValue = default;
254-
}
255-
256-
class EnumProperty<string storageTypeParam, string desc = "", string default = ""> :
257-
EnumProp<storageTypeParam, desc, default>,
258-
Deprecated<"moved to shorter name EnumProp">;
259-
260241
// Note: only a class so we can deprecate the old name
261242
class _cls_StringProp : Property<"std::string", "string"> {
262243
let interfaceType = "::llvm::StringRef";

mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp

Lines changed: 0 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -49,70 +49,6 @@ using mlir::LLVM::tailcallkind::getMaxEnumValForTailCallKind;
4949

5050
#include "mlir/Dialect/LLVMIR/LLVMOpsDialect.cpp.inc"
5151

52-
//===----------------------------------------------------------------------===//
53-
// Property Helpers
54-
//===----------------------------------------------------------------------===//
55-
56-
//===----------------------------------------------------------------------===//
57-
// IntegerOverflowFlags
58-
59-
namespace mlir {
60-
static Attribute convertToAttribute(MLIRContext *ctx,
61-
IntegerOverflowFlags flags) {
62-
return IntegerOverflowFlagsAttr::get(ctx, flags);
63-
}
64-
65-
static LogicalResult
66-
convertFromAttribute(IntegerOverflowFlags &flags, Attribute attr,
67-
function_ref<InFlightDiagnostic()> emitError) {
68-
auto flagsAttr = dyn_cast<IntegerOverflowFlagsAttr>(attr);
69-
if (!flagsAttr) {
70-
return emitError() << "expected 'overflowFlags' attribute to be an "
71-
"IntegerOverflowFlagsAttr, but got "
72-
<< attr;
73-
}
74-
flags = flagsAttr.getValue();
75-
return success();
76-
}
77-
} // namespace mlir
78-
79-
static ParseResult parseOverflowFlags(AsmParser &p,
80-
IntegerOverflowFlags &flags) {
81-
if (failed(p.parseOptionalKeyword("overflow"))) {
82-
flags = IntegerOverflowFlags::none;
83-
return success();
84-
}
85-
if (p.parseLess())
86-
return failure();
87-
do {
88-
StringRef kw;
89-
SMLoc loc = p.getCurrentLocation();
90-
if (p.parseKeyword(&kw))
91-
return failure();
92-
std::optional<IntegerOverflowFlags> flag =
93-
symbolizeIntegerOverflowFlags(kw);
94-
if (!flag)
95-
return p.emitError(loc,
96-
"invalid overflow flag: expected nsw, nuw, or none");
97-
flags = flags | *flag;
98-
} while (succeeded(p.parseOptionalComma()));
99-
return p.parseGreater();
100-
}
101-
102-
static void printOverflowFlags(AsmPrinter &p, Operation *op,
103-
IntegerOverflowFlags flags) {
104-
if (flags == IntegerOverflowFlags::none)
105-
return;
106-
p << " overflow<";
107-
SmallVector<StringRef, 2> strs;
108-
if (bitEnumContainsAny(flags, IntegerOverflowFlags::nsw))
109-
strs.push_back("nsw");
110-
if (bitEnumContainsAny(flags, IntegerOverflowFlags::nuw))
111-
strs.push_back("nuw");
112-
llvm::interleaveComma(strs, p);
113-
p << ">";
114-
}
115-
11652
//===----------------------------------------------------------------------===//
11753
// Attribute Helpers
11854
//===----------------------------------------------------------------------===//

mlir/test/IR/enum-attr-invalid.mlir

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,3 +28,78 @@ func.func @test_parse_invalid_attr() -> () {
2828
// expected-error@+1 {{failed to parse TestEnumAttr parameter 'value'}}
2929
test.op_with_enum 1 : index
3030
}
31+
32+
// -----
33+
34+
func.func @test_non_keyword_prop_enum() -> () {
35+
// expected-error@+2 {{expected keyword for a test enum}}
36+
// expected-error@+1 {{invalid value for property value, expected a test enum}}
37+
test.op_with_enum_prop 0
38+
return
39+
}
40+
41+
// -----
42+
43+
func.func @test_wrong_keyword_prop_enum() -> () {
44+
// expected-error@+2 {{expected one of [first, second, third] for a test enum, got: fourth}}
45+
// expected-error@+1 {{invalid value for property value, expected a test enum}}
46+
test.op_with_enum_prop fourth
47+
}
48+
49+
// -----
50+
51+
func.func @test_bad_integer() -> () {
52+
// expected-error@+1 {{op property 'value' failed to satisfy constraint: a test enum}}
53+
"test.op_with_enum_prop"() <{value = 4 : i32}> {} : () -> ()
54+
}
55+
56+
// -----
57+
58+
func.func @test_bit_enum_prop_not_keyword() -> () {
59+
// expected-error@+2 {{expected keyword for a test bit enum}}
60+
// expected-error@+1 {{invalid value for property value1, expected a test bit enum}}
61+
test.op_with_bit_enum_prop 0
62+
return
63+
}
64+
65+
// -----
66+
67+
func.func @test_bit_enum_prop_wrong_keyword() -> () {
68+
// expected-error@+2 {{expected one of [read, write, execute] for a test bit enum, got: chroot}}
69+
// expected-error@+1 {{invalid value for property value1, expected a test bit enum}}
70+
test.op_with_bit_enum_prop read, chroot : ()
71+
return
72+
}
73+
74+
// -----
75+
76+
func.func @test_bit_enum_prop_bad_value() -> () {
77+
// expected-error@+1 {{op property 'value2' failed to satisfy constraint: a test bit enum}}
78+
"test.op_with_bit_enum_prop"() <{value1 = 7 : i32, value2 = 8 : i32}> {} : () -> ()
79+
return
80+
}
81+
82+
// -----
83+
84+
func.func @test_bit_enum_prop_named_wrong_keyword() -> () {
85+
// expected-error@+2 {{expected 'bit_enum'}}
86+
// expected-error@+1 {{invalid value for property value1, expected a test bit enum}}
87+
test.op_with_bit_enum_prop_named foo<read, execute>
88+
return
89+
}
90+
91+
// -----
92+
93+
func.func @test_bit_enum_prop_named_not_open() -> () {
94+
// expected-error@+2 {{expected '<'}}
95+
// expected-error@+1 {{invalid value for property value1, expected a test bit enum}}
96+
test.op_with_bit_enum_prop_named bit_enum read, execute>
97+
}
98+
99+
// -----
100+
101+
func.func @test_bit_enum_prop_named_not_closed() -> () {
102+
// expected-error@+2 {{expected '>'}}
103+
// expected-error@+1 {{invalid value for property value1, expected a test bit enum}}
104+
test.op_with_bit_enum_prop_named bit_enum<read, execute +
105+
}

0 commit comments

Comments
 (0)