Skip to content

Commit fee9054

Browse files
author
Vladislav Vinogradov
committed
[mlir][ODS] Support specialized Attribute class for Enums
Add a feature to `EnumAttr` definition to generate specialized Attribute class for the particular enumeration. This class will inherit `StringAttr` or `IntegerAttr` and will override `classof` and `getValue` methods. With this class the enumeration predicate can be checked with simple RTTI calls (`isa`, `dyn_cast`) and it will return the typed enumeration directly instead of raw string/integer. Based on the following discussion: https://llvm.discourse.group/t/rfc-add-enum-attribute-decorator-class/2252 Reviewed By: rriddle Differential Revision: https://reviews.llvm.org/D97836
1 parent 2571a09 commit fee9054

File tree

20 files changed

+266
-96
lines changed

20 files changed

+266
-96
lines changed

mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,7 @@ def LLVM_ICmpOp : LLVM_Op<"icmp", [NoSideEffect]> {
200200
OpBuilder<(ins "ICmpPredicate":$predicate, "Value":$lhs, "Value":$rhs),
201201
[{
202202
build($_builder, $_state, IntegerType::get(lhs.getType().getContext(), 1),
203-
$_builder.getI64IntegerAttr(static_cast<int64_t>(predicate)), lhs, rhs);
203+
predicate, lhs, rhs);
204204
}]>];
205205
let parser = [{ return parseCmpOp<ICmpPredicate>(parser, result); }];
206206
let printer = [{ printICmpOp(p, *this); }];
@@ -246,14 +246,6 @@ def LLVM_FCmpOp : LLVM_Op<"fcmp", [
246246
let llvmBuilder = [{
247247
$res = builder.CreateFCmp(getLLVMCmpPredicate($predicate), $lhs, $rhs);
248248
}];
249-
let builders = [
250-
OpBuilder<(ins "FCmpPredicate":$predicate, "Value":$lhs, "Value":$rhs,
251-
CArg<"FastmathFlags", "{}">:$fmf),
252-
[{
253-
build($_builder, $_state, IntegerType::get(lhs.getType().getContext(), 1),
254-
$_builder.getI64IntegerAttr(static_cast<int64_t>(predicate)), lhs, rhs,
255-
::mlir::LLVM::FMFAttr::get($_builder.getContext(), fmf));
256-
}]>];
257249
let parser = [{ return parseCmpOp<FCmpPredicate>(parser, result); }];
258250
let printer = [{ printFCmpOp(p, *this); }];
259251
}

mlir/include/mlir/Dialect/SPIRV/IR/SPIRVEnums.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#ifndef MLIR_DIALECT_SPIRV_IR_SPIRVENUMS_H_
1414
#define MLIR_DIALECT_SPIRV_IR_SPIRVENUMS_H_
1515

16+
#include "mlir/IR/BuiltinAttributes.h"
1617
#include "mlir/Support/LLVM.h"
1718
#include "llvm/ADT/DenseMapInfo.h"
1819
#include "llvm/ADT/StringRef.h"

mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMemoryOps.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -184,7 +184,7 @@ def SPV_LoadOp : SPV_Op<"Load", []> {
184184

185185
let builders = [
186186
OpBuilder<(ins "Value":$basePtr,
187-
CArg<"IntegerAttr", "{}">:$memory_access,
187+
CArg<"MemoryAccessAttr", "{}">:$memory_access,
188188
CArg<"IntegerAttr", "{}">:$alignment)>
189189
];
190190
}

mlir/include/mlir/Dialect/Vector/VectorOps.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ def CombiningKind : BitEnumAttr<
5353
COMBINING_KIND_MAX, COMBINING_KIND_AND, COMBINING_KIND_OR,
5454
COMBINING_KIND_XOR]> {
5555
let cppNamespace = "::mlir::vector";
56+
let genSpecializedAttr = 0;
5657
}
5758

5859
def Vector_CombiningKindAttr : DialectAttr<

mlir/include/mlir/IR/OpBase.td

Lines changed: 49 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1142,7 +1142,9 @@ class BitEnumAttrCase<string sym, int val, string str = sym> :
11421142
}
11431143

11441144
// Additional information for an enum attribute.
1145-
class EnumAttrInfo<string name, list<EnumAttrCaseInfo> cases> {
1145+
class EnumAttrInfo<
1146+
string name, list<EnumAttrCaseInfo> cases, Attr baseClass> :
1147+
Attr<baseClass.predicate, baseClass.summary> {
11461148
// The C++ enum class name
11471149
string className = name;
11481150

@@ -1188,54 +1190,73 @@ class EnumAttrInfo<string name, list<EnumAttrCaseInfo> cases> {
11881190
// static constexpr unsigned <fn-name>();
11891191
// ```
11901192
string maxEnumValFnName = "getMaxEnumValFor" # name;
1193+
1194+
// Generate specialized Attribute class
1195+
bit genSpecializedAttr = 1;
1196+
// The underlying Attribute class, which holds the enum value
1197+
Attr baseAttrClass = baseClass;
1198+
// The name of specialized Enum Attribute class
1199+
string specializedAttrClassName = name # Attr;
1200+
1201+
// Override Attr class fields for specialized class
1202+
let predicate = !if(genSpecializedAttr,
1203+
CPred<"$_self.isa<" # cppNamespace # "::" # specializedAttrClassName # ">()">,
1204+
baseAttrClass.predicate);
1205+
let storageType = !if(genSpecializedAttr,
1206+
cppNamespace # "::" # specializedAttrClassName,
1207+
baseAttrClass.storageType);
1208+
let returnType = !if(genSpecializedAttr,
1209+
cppNamespace # "::" # className,
1210+
baseAttrClass.returnType);
1211+
let constBuilderCall = !if(genSpecializedAttr,
1212+
cppNamespace # "::" # specializedAttrClassName # "::get($_builder.getContext(), $0)",
1213+
baseAttrClass.constBuilderCall);
1214+
let valueType = baseAttrClass.valueType;
11911215
}
11921216

11931217
// An enum attribute backed by StringAttr.
11941218
//
11951219
// Op attributes of this kind are stored as StringAttr. Extra verification will
11961220
// be generated on the string though: only the symbols of the allowed cases are
11971221
// permitted as the string value.
1198-
class StrEnumAttr<string name, string summary, list<StrEnumAttrCase> cases>
1199-
: EnumAttrInfo<name, cases>,
1222+
class StrEnumAttr<string name, string summary, list<StrEnumAttrCase> cases> :
1223+
EnumAttrInfo<name, cases,
12001224
StringBasedAttr<
12011225
And<[StrAttr.predicate, Or<!foreach(case, cases, case.predicate)>]>,
12021226
!if(!empty(summary), "allowed string cases: " #
12031227
!interleave(!foreach(case, cases, "'" # case.symbol # "'"), ", "),
1204-
summary)>;
1228+
summary)>> {
1229+
// Disable specialized Attribute class for `StringAttr` backend by default.
1230+
let genSpecializedAttr = 0;
1231+
}
12051232

12061233
// An enum attribute backed by IntegerAttr.
12071234
//
12081235
// Op attributes of this kind are stored as IntegerAttr. Extra verification will
12091236
// be generated on the integer though: only the values of the allowed cases are
12101237
// permitted as the integer value.
1211-
class IntEnumAttr<I intType, string name, string summary,
1212-
list<IntEnumAttrCaseBase> cases> :
1213-
EnumAttrInfo<name, cases>,
1214-
SignlessIntegerAttrBase<intType,
1215-
!if(!empty(summary), "allowed " # intType.summary # " cases: " #
1216-
!interleave(!foreach(case, cases, case.value), ", "), summary)> {
1238+
class IntEnumAttrBase<I intType, list<IntEnumAttrCaseBase> cases, string summary> :
1239+
SignlessIntegerAttrBase<intType, summary> {
12171240
let predicate = And<[
1218-
SignlessIntegerAttrBase<intType, "">.predicate,
1241+
SignlessIntegerAttrBase<intType, summary>.predicate,
12191242
Or<!foreach(case, cases, case.predicate)>]>;
12201243
}
12211244

1222-
class I32EnumAttr<string name, string summary,
1223-
list<I32EnumAttrCase> cases> :
1245+
class IntEnumAttr<I intType, string name, string summary,
1246+
list<IntEnumAttrCaseBase> cases> :
1247+
EnumAttrInfo<name, cases,
1248+
IntEnumAttrBase<intType, cases,
1249+
!if(!empty(summary), "allowed " # intType.summary # " cases: " #
1250+
!interleave(!foreach(case, cases, case.value), ", "),
1251+
summary)>>;
1252+
1253+
class I32EnumAttr<string name, string summary, list<I32EnumAttrCase> cases> :
12241254
IntEnumAttr<I32, name, summary, cases> {
1225-
let returnType = cppNamespace # "::" # name;
12261255
let underlyingType = "uint32_t";
1227-
let convertFromStorage = "static_cast<" # returnType # ">($_self.getInt())";
1228-
let constBuilderCall =
1229-
"$_builder.getI32IntegerAttr(static_cast<int32_t>($0))";
12301256
}
1231-
class I64EnumAttr<string name, string summary,
1232-
list<I64EnumAttrCase> cases> :
1257+
class I64EnumAttr<string name, string summary, list<I64EnumAttrCase> cases> :
12331258
IntEnumAttr<I64, name, summary, cases> {
1234-
let returnType = cppNamespace # "::" # name;
12351259
let underlyingType = "uint64_t";
1236-
let convertFromStorage = "static_cast<" # returnType # ">($_self.getInt())";
1237-
let constBuilderCall =
1238-
"$_builder.getI64IntegerAttr(static_cast<int64_t>($0))";
12391260
}
12401261

12411262
// A bit enum stored with 32-bit IntegerAttr.
@@ -1244,22 +1265,20 @@ class I64EnumAttr<string name, string summary,
12441265
// be generated on the integer to make sure only allowed bit are set. Besides,
12451266
// helper methods are generated to parse a string separated with a specified
12461267
// delimiter to a symbol and vice versa.
1247-
class BitEnumAttr<string name, string summary,
1248-
list<BitEnumAttrCase> cases> :
1249-
EnumAttrInfo<name, cases>, SignlessIntegerAttrBase<I32, summary> {
1268+
class BitEnumAttrBase<list<BitEnumAttrCase> cases, string summary> :
1269+
SignlessIntegerAttrBase<I32, summary> {
12501270
let predicate = And<[
12511271
I32Attr.predicate,
12521272
// Make sure we don't have unknown bit set.
12531273
CPred<"!($_self.cast<::mlir::IntegerAttr>().getValue().getZExtValue() & (~("
12541274
# !interleave(!foreach(case, cases, case.value # "u"), "|") #
12551275
")))">
12561276
]>;
1277+
}
12571278

1258-
let returnType = cppNamespace # "::" # name;
1279+
class BitEnumAttr<string name, string summary, list<BitEnumAttrCase> cases> :
1280+
EnumAttrInfo<name, cases, BitEnumAttrBase<cases, summary>> {
12591281
let underlyingType = "uint32_t";
1260-
let convertFromStorage = "static_cast<" # returnType # ">($_self.getInt())";
1261-
let constBuilderCall =
1262-
"$_builder.getI32IntegerAttr(static_cast<int32_t>($0))";
12631282

12641283
// We need to return a string because we may concatenate symbols for multiple
12651284
// bits together.

mlir/include/mlir/TableGen/Attribute.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,10 @@ class EnumAttr : public Attribute {
202202

203203
// Returns all allowed cases for this enum attribute.
204204
std::vector<EnumAttrCase> getAllCases() const;
205+
206+
bool genSpecializedAttr() const;
207+
llvm::Record *getBaseAttrClass() const;
208+
StringRef getSpecializedAttrClassName() const;
205209
};
206210

207211
class StructFieldAttr {

mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -155,9 +155,7 @@ ForOpConversion::matchAndRewrite(scf::ForOp forOp, ArrayRef<Value> operands,
155155
// header to merge.
156156
scf::ForOpAdaptor forOperands(operands);
157157
auto loc = forOp.getLoc();
158-
auto loopControl = rewriter.getI32IntegerAttr(
159-
static_cast<uint32_t>(spirv::LoopControl::None));
160-
auto loopOp = rewriter.create<spirv::LoopOp>(loc, loopControl);
158+
auto loopOp = rewriter.create<spirv::LoopOp>(loc, spirv::LoopControl::None);
161159
loopOp.addEntryAndMergeBlock();
162160

163161
OpBuilder::InsertionGuard guard(rewriter);
@@ -238,11 +236,9 @@ IfOpConversion::matchAndRewrite(scf::IfOp ifOp, ArrayRef<Value> operands,
238236
scf::IfOpAdaptor ifOperands(operands);
239237
auto loc = ifOp.getLoc();
240238

241-
// Create `spv.mlir.selection` operation, selection header block and merge
242-
// block.
243-
auto selectionControl = rewriter.getI32IntegerAttr(
244-
static_cast<uint32_t>(spirv::SelectionControl::None));
245-
auto selectionOp = rewriter.create<spirv::SelectionOp>(loc, selectionControl);
239+
// Create `spv.selection` operation, selection header block and merge block.
240+
auto selectionOp =
241+
rewriter.create<spirv::SelectionOp>(loc, spirv::SelectionControl::None);
246242
auto *mergeBlock =
247243
rewriter.createBlock(&selectionOp.body(), selectionOp.body().end());
248244
rewriter.create<spirv::MergeOp>(loc);

mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -826,10 +826,8 @@ class FComparePattern : public SPIRVToLLVMConversion<SPIRVOp> {
826826
return failure();
827827

828828
rewriter.template replaceOpWithNewOp<LLVM::FCmpOp>(
829-
operation, dstType,
830-
rewriter.getI64IntegerAttr(static_cast<int64_t>(predicate)),
831-
operation.operand1(), operation.operand2(),
832-
LLVM::FMFAttr::get(operation.getContext(), {}));
829+
operation, dstType, predicate, operation.operand1(),
830+
operation.operand2());
833831
return success();
834832
}
835833
};
@@ -849,9 +847,8 @@ class IComparePattern : public SPIRVToLLVMConversion<SPIRVOp> {
849847
return failure();
850848

851849
rewriter.template replaceOpWithNewOp<LLVM::ICmpOp>(
852-
operation, dstType,
853-
rewriter.getI64IntegerAttr(static_cast<int64_t>(predicate)),
854-
operation.operand1(), operation.operand2());
850+
operation, dstType, predicate, operation.operand1(),
851+
operation.operand2());
855852
return success();
856853
}
857854
};

mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3069,8 +3069,7 @@ struct CmpIOpLowering : public ConvertOpToLLVMPattern<CmpIOp> {
30693069

30703070
rewriter.replaceOpWithNewOp<LLVM::ICmpOp>(
30713071
cmpiOp, typeConverter->convertType(cmpiOp.getResult().getType()),
3072-
rewriter.getI64IntegerAttr(static_cast<int64_t>(
3073-
convertCmpPredicate<LLVM::ICmpPredicate>(cmpiOp.getPredicate()))),
3072+
convertCmpPredicate<LLVM::ICmpPredicate>(cmpiOp.getPredicate()),
30743073
transformed.lhs(), transformed.rhs());
30753074

30763075
return success();
@@ -3085,12 +3084,10 @@ struct CmpFOpLowering : public ConvertOpToLLVMPattern<CmpFOp> {
30853084
ConversionPatternRewriter &rewriter) const override {
30863085
CmpFOpAdaptor transformed(operands);
30873086

3088-
auto fmf = LLVM::FMFAttr::get(cmpfOp.getContext(), {});
30893087
rewriter.replaceOpWithNewOp<LLVM::FCmpOp>(
30903088
cmpfOp, typeConverter->convertType(cmpfOp.getResult().getType()),
3091-
rewriter.getI64IntegerAttr(static_cast<int64_t>(
3092-
convertCmpPredicate<LLVM::FCmpPredicate>(cmpfOp.getPredicate()))),
3093-
transformed.lhs(), transformed.rhs(), fmf);
3089+
convertCmpPredicate<LLVM::FCmpPredicate>(cmpfOp.getPredicate()),
3090+
transformed.lhs(), transformed.rhs());
30943091

30953092
return success();
30963093
}

mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1017,7 +1017,7 @@ IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp,
10171017
srcBits, dstBits, rewriter);
10181018
Value spvLoadOp = rewriter.create<spirv::LoadOp>(
10191019
loc, dstType, adjustedPtr,
1020-
loadOp->getAttrOfType<IntegerAttr>(
1020+
loadOp->getAttrOfType<spirv::MemoryAccessAttr>(
10211021
spirv::attributeName<spirv::MemoryAccess>()),
10221022
loadOp->getAttrOfType<IntegerAttr>("alignment"));
10231023

mlir/lib/Dialect/GPU/Transforms/ParallelLoopMapper.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ ParallelLoopDimMapping getParallelLoopDimMappingAttr(Processor processor,
3636
MLIRContext *context = map.getContext();
3737
OpBuilder builder(context);
3838
return ParallelLoopDimMapping::get(
39-
builder.getI64IntegerAttr(static_cast<int32_t>(processor)),
39+
ProcessorAttr::get(builder.getContext(), processor),
4040
AffineMapAttr::get(map), AffineMapAttr::get(bound), context);
4141
}
4242

mlir/lib/Dialect/SPIRV/IR/SPIRVEnums.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212

1313
#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
1414

15+
#include "mlir/IR/BuiltinTypes.h"
16+
1517
#include "llvm/ADT/SetVector.h"
1618
#include "llvm/ADT/StringExtras.h"
1719
#include "llvm/ADT/StringRef.h"

mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1659,7 +1659,7 @@ void spirv::EntryPointOp::build(OpBuilder &builder, OperationState &state,
16591659
spirv::FuncOp function,
16601660
ArrayRef<Attribute> interfaceVars) {
16611661
build(builder, state,
1662-
builder.getI32IntegerAttr(static_cast<int32_t>(executionModel)),
1662+
spirv::ExecutionModelAttr::get(builder.getContext(), executionModel),
16631663
builder.getSymbolRefAttr(function),
16641664
builder.getArrayAttr(interfaceVars));
16651665
}
@@ -1721,7 +1721,7 @@ void spirv::ExecutionModeOp::build(OpBuilder &builder, OperationState &state,
17211721
spirv::ExecutionMode executionMode,
17221722
ArrayRef<int32_t> params) {
17231723
build(builder, state, builder.getSymbolRefAttr(function),
1724-
builder.getI32IntegerAttr(static_cast<int32_t>(executionMode)),
1724+
spirv::ExecutionModeAttr::get(builder.getContext(), executionMode),
17251725
builder.getI32ArrayAttr(params));
17261726
}
17271727

@@ -2243,10 +2243,10 @@ static LogicalResult verify(spirv::GroupNonUniformElectOp groupOp) {
22432243
//===----------------------------------------------------------------------===//
22442244

22452245
void spirv::LoadOp::build(OpBuilder &builder, OperationState &state,
2246-
Value basePtr, IntegerAttr memory_access,
2246+
Value basePtr, MemoryAccessAttr memoryAccess,
22472247
IntegerAttr alignment) {
22482248
auto ptrType = basePtr.getType().cast<spirv::PointerType>();
2249-
build(builder, state, ptrType.getPointeeType(), basePtr, memory_access,
2249+
build(builder, state, ptrType.getPointeeType(), basePtr, memoryAccess,
22502250
alignment);
22512251
}
22522252

@@ -2784,9 +2784,8 @@ void spirv::SelectionOp::addMergeBlock() {
27842784
spirv::SelectionOp spirv::SelectionOp::createIfThen(
27852785
Location loc, Value condition,
27862786
function_ref<void(OpBuilder &builder)> thenBody, OpBuilder &builder) {
2787-
auto selectionControl = builder.getI32IntegerAttr(
2788-
static_cast<uint32_t>(spirv::SelectionControl::None));
2789-
auto selectionOp = builder.create<spirv::SelectionOp>(loc, selectionControl);
2787+
auto selectionOp =
2788+
builder.create<spirv::SelectionOp>(loc, spirv::SelectionControl::None);
27902789

27912790
selectionOp.addMergeBlock();
27922791
Block *mergeBlock = selectionOp.getMergeBlock();

mlir/lib/TableGen/Attribute.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -231,6 +231,18 @@ std::vector<EnumAttrCase> EnumAttr::getAllCases() const {
231231
return cases;
232232
}
233233

234+
bool EnumAttr::genSpecializedAttr() const {
235+
return def->getValueAsBit("genSpecializedAttr");
236+
}
237+
238+
llvm::Record *EnumAttr::getBaseAttrClass() const {
239+
return def->getValueAsDef("baseAttrClass");
240+
}
241+
242+
StringRef EnumAttr::getSpecializedAttrClassName() const {
243+
return def->getValueAsString("specializedAttrClassName");
244+
}
245+
234246
StructFieldAttr::StructFieldAttr(const llvm::Record *record) : def(record) {
235247
assert(def->isSubClassOf("StructFieldAttr") &&
236248
"must be subclass of TableGen 'StructFieldAttr' class");

0 commit comments

Comments
 (0)