Skip to content

Commit 159470d

Browse files
authored
[mlir] [tblgen-to-irdl] Add attributes to tblgen-to-irdl script (#109633)
Adds the ability to export attributes from the dialect and attributes of operations in the dialect
1 parent 8536d48 commit 159470d

File tree

3 files changed

+152
-7
lines changed

3 files changed

+152
-7
lines changed

mlir/include/mlir/IR/CommonAttrConstraints.td

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,7 @@ class AnyAttrOf<list<Attr> allowedAttrs, string summary = "",
178178
summary)> {
179179
let returnType = cppType;
180180
let convertFromStorage = fromStorage;
181+
list<Attr> allowedAttributes = allowedAttrs;
181182
}
182183

183184
def LocationAttr : Attr<CPred<"::llvm::isa<::mlir::LocationAttr>($_self)">,
@@ -743,6 +744,8 @@ class ConfinedAttr<Attr attr, list<AttrConstraint> constraints> : Attr<
743744
let isOptional = attr.isOptional;
744745

745746
let baseAttr = attr;
747+
748+
list<AttrConstraint> attrConstraints = constraints;
746749
}
747750

748751
// An AttrConstraint that holds if all attr constraints specified in

mlir/test/tblgen-to-irdl/TestDialect.td

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,10 @@ class Test_Type<string name, string typeMnemonic, list<Trait> traits = []>
1313
let mnemonic = typeMnemonic;
1414
}
1515

16+
class Test_Attr<string name, string attrMnemonic> : AttrDef<Test_Dialect, name> {
17+
let mnemonic = attrMnemonic;
18+
}
19+
1620
class Test_Op<string mnemonic, list<Trait> traits = []>
1721
: Op<Test_Dialect, mnemonic, traits>;
1822

@@ -22,6 +26,8 @@ def Test_SingletonAType : Test_Type<"SingletonAType", "singleton_a"> {}
2226
def Test_SingletonBType : Test_Type<"SingletonBType", "singleton_b"> {}
2327
// CHECK: irdl.type @"!singleton_c"
2428
def Test_SingletonCType : Test_Type<"SingletonCType", "singleton_c"> {}
29+
// CHECK: irdl.attribute @"#test"
30+
def Test_TestAttr : Test_Attr<"Test", "test"> {}
2531

2632

2733
// Check that AllOfType is converted correctly.
@@ -45,6 +51,17 @@ def Test_AnyOp : Test_Op<"any"> {
4551
// CHECK-NEXT: irdl.operands(%[[v0]])
4652
// CHECK-NEXT: }
4753

54+
// Check attributes are converted correctly.
55+
def Test_AttributesOp : Test_Op<"attributes"> {
56+
let arguments = (ins I16Attr:$int_attr,
57+
Test_TestAttr:$test_attr);
58+
}
59+
// CHECK-LABEL: irdl.operation @attributes {
60+
// CHECK-NEXT: %[[v0:[^ ]*]] = irdl.base "!builtin.integer"
61+
// CHECK-NEXT: %[[v1:[^ ]*]] = irdl.base @test::@"#test"
62+
// CHECK-NEXT: irdl.attributes {"int_attr" = %[[v0]], "test_attr" = %[[v1]]}
63+
// CHECK-NEXT: }
64+
4865
// Check confined types are converted correctly.
4966
def Test_ConfinedOp : Test_Op<"confined"> {
5067
let arguments = (ins ConfinedType<AnyType, [CPred<"::llvm::isa<::mlir::TensorType>($_self)">]>:$tensor,

mlir/tools/tblgen-to-irdl/OpDefinitionsGen.cpp

Lines changed: 132 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -74,8 +74,14 @@ Value typeToConstraint(OpBuilder &builder, Type type) {
7474
return op.getOutput();
7575
}
7676

77-
std::optional<Type> recordToType(MLIRContext *ctx, const Record &predRec) {
77+
Value baseToConstraint(OpBuilder &builder, StringRef baseClass) {
78+
MLIRContext *ctx = builder.getContext();
79+
auto op = builder.create<irdl::BaseOp>(UnknownLoc::get(ctx),
80+
StringAttr::get(ctx, baseClass));
81+
return op.getOutput();
82+
}
7883

84+
std::optional<Type> recordToType(MLIRContext *ctx, const Record &predRec) {
7985
if (predRec.isSubClassOf("I")) {
8086
auto width = predRec.getValueAsInt("bitwidth");
8187
return IntegerType::get(ctx, width, IntegerType::Signless);
@@ -164,12 +170,12 @@ std::optional<Type> recordToType(MLIRContext *ctx, const Record &predRec) {
164170
return std::nullopt;
165171
}
166172

167-
Value createConstraint(OpBuilder &builder, tblgen::Constraint constraint) {
173+
Value createTypeConstraint(OpBuilder &builder, tblgen::Constraint constraint) {
168174
MLIRContext *ctx = builder.getContext();
169175
const Record &predRec = constraint.getDef();
170176

171177
if (predRec.isSubClassOf("Variadic") || predRec.isSubClassOf("Optional"))
172-
return createConstraint(builder, predRec.getValueAsDef("baseType"));
178+
return createTypeConstraint(builder, predRec.getValueAsDef("baseType"));
173179

174180
if (predRec.getName() == "AnyType") {
175181
auto op = builder.create<irdl::AnyOp>(UnknownLoc::get(ctx));
@@ -196,7 +202,7 @@ Value createConstraint(OpBuilder &builder, tblgen::Constraint constraint) {
196202
std::vector<Value> constraints;
197203
for (const Record *child : predRec.getValueAsListOfDefs("allowedTypes")) {
198204
constraints.push_back(
199-
createConstraint(builder, tblgen::Constraint(child)));
205+
createTypeConstraint(builder, tblgen::Constraint(child)));
200206
}
201207
auto op = builder.create<irdl::AnyOfOp>(UnknownLoc::get(ctx), constraints);
202208
return op.getOutput();
@@ -206,7 +212,7 @@ Value createConstraint(OpBuilder &builder, tblgen::Constraint constraint) {
206212
std::vector<Value> constraints;
207213
for (const Record *child : predRec.getValueAsListOfDefs("allowedTypes")) {
208214
constraints.push_back(
209-
createConstraint(builder, tblgen::Constraint(child)));
215+
createTypeConstraint(builder, tblgen::Constraint(child)));
210216
}
211217
auto op = builder.create<irdl::AllOfOp>(UnknownLoc::get(ctx), constraints);
212218
return op.getOutput();
@@ -241,7 +247,7 @@ Value createConstraint(OpBuilder &builder, tblgen::Constraint constraint) {
241247
// Confined type
242248
if (predRec.isSubClassOf("ConfinedType")) {
243249
std::vector<Value> constraints;
244-
constraints.push_back(createConstraint(
250+
constraints.push_back(createTypeConstraint(
245251
builder, tblgen::Constraint(predRec.getValueAsDef("baseType"))));
246252
for (Record *child : predRec.getValueAsListOfDefs("predicateList")) {
247253
constraints.push_back(createPredicate(builder, tblgen::Pred(child)));
@@ -253,6 +259,85 @@ Value createConstraint(OpBuilder &builder, tblgen::Constraint constraint) {
253259
return createPredicate(builder, constraint.getPredicate());
254260
}
255261

262+
Value createAttrConstraint(OpBuilder &builder, tblgen::Constraint constraint) {
263+
MLIRContext *ctx = builder.getContext();
264+
const Record &predRec = constraint.getDef();
265+
266+
if (predRec.isSubClassOf("DefaultValuedAttr") ||
267+
predRec.isSubClassOf("DefaultValuedOptionalAttr") ||
268+
predRec.isSubClassOf("OptionalAttr")) {
269+
return createAttrConstraint(builder, predRec.getValueAsDef("baseAttr"));
270+
}
271+
272+
if (predRec.isSubClassOf("ConfinedAttr")) {
273+
std::vector<Value> constraints;
274+
constraints.push_back(createAttrConstraint(
275+
builder, tblgen::Constraint(predRec.getValueAsDef("baseAttr"))));
276+
for (Record *child : predRec.getValueAsListOfDefs("attrConstraints")) {
277+
constraints.push_back(createPredicate(
278+
builder, tblgen::Pred(child->getValueAsDef("predicate"))));
279+
}
280+
auto op = builder.create<irdl::AllOfOp>(UnknownLoc::get(ctx), constraints);
281+
return op.getOutput();
282+
}
283+
284+
if (predRec.isSubClassOf("AnyAttrOf")) {
285+
std::vector<Value> constraints;
286+
for (Record *child : predRec.getValueAsListOfDefs("allowedAttributes")) {
287+
constraints.push_back(
288+
createAttrConstraint(builder, tblgen::Constraint(child)));
289+
}
290+
auto op = builder.create<irdl::AnyOfOp>(UnknownLoc::get(ctx), constraints);
291+
return op.getOutput();
292+
}
293+
294+
if (predRec.getName() == "AnyAttr") {
295+
auto op = builder.create<irdl::AnyOp>(UnknownLoc::get(ctx));
296+
return op.getOutput();
297+
}
298+
299+
if (predRec.isSubClassOf("AnyIntegerAttrBase") ||
300+
predRec.isSubClassOf("SignlessIntegerAttrBase") ||
301+
predRec.isSubClassOf("SignedIntegerAttrBase") ||
302+
predRec.isSubClassOf("UnsignedIntegerAttrBase") ||
303+
predRec.isSubClassOf("BoolAttr")) {
304+
return baseToConstraint(builder, "!builtin.integer");
305+
}
306+
307+
if (predRec.isSubClassOf("FloatAttrBase")) {
308+
return baseToConstraint(builder, "!builtin.float");
309+
}
310+
311+
if (predRec.isSubClassOf("StringBasedAttr")) {
312+
return baseToConstraint(builder, "!builtin.string");
313+
}
314+
315+
if (predRec.getName() == "UnitAttr") {
316+
auto op =
317+
builder.create<irdl::IsOp>(UnknownLoc::get(ctx), UnitAttr::get(ctx));
318+
return op.getOutput();
319+
}
320+
321+
if (predRec.isSubClassOf("AttrDef")) {
322+
auto dialect = predRec.getValueAsDef("dialect")->getValueAsString("name");
323+
if (dialect == selectedDialect) {
324+
std::string combined = ("#" + predRec.getValueAsString("mnemonic")).str();
325+
SmallVector<FlatSymbolRefAttr> nested = {SymbolRefAttr::get(ctx, combined)
326+
327+
};
328+
auto typeSymbol = SymbolRefAttr::get(ctx, dialect, nested);
329+
auto op = builder.create<irdl::BaseOp>(UnknownLoc::get(ctx), typeSymbol);
330+
return op.getOutput();
331+
}
332+
std::string typeName = ("#" + predRec.getValueAsString("attrName")).str();
333+
auto op = builder.create<irdl::BaseOp>(UnknownLoc::get(ctx),
334+
StringAttr::get(ctx, typeName));
335+
return op.getOutput();
336+
}
337+
338+
return createPredicate(builder, constraint.getPredicate());
339+
}
340+
256341
/// Returns the name of the operation without the dialect prefix.
257342
static StringRef getOperatorName(tblgen::Operator &tblgenOp) {
258343
StringRef opName = tblgenOp.getDef().getValueAsString("opName");
@@ -265,6 +350,12 @@ static StringRef getTypeName(tblgen::TypeDef &tblgenType) {
265350
return opName;
266351
}
267352

353+
/// Returns the name of the attr without the dialect prefix.
354+
static StringRef getAttrName(tblgen::AttrDef &tblgenType) {
355+
StringRef opName = tblgenType.getDef()->getValueAsString("mnemonic");
356+
return opName;
357+
}
358+
268359
/// Extract an operation to IRDL.
269360
irdl::OperationOp createIRDLOperation(OpBuilder &builder,
270361
tblgen::Operator &tblgenOp) {
@@ -282,7 +373,7 @@ irdl::OperationOp createIRDLOperation(OpBuilder &builder,
282373
SmallVector<Value> operands;
283374
SmallVector<irdl::VariadicityAttr> variadicity;
284375
for (const NamedTypeConstraint &namedCons : namedCons) {
285-
auto operand = createConstraint(consBuilder, namedCons.constraint);
376+
auto operand = createTypeConstraint(consBuilder, namedCons.constraint);
286377
operands.push_back(operand);
287378

288379
irdl::VariadicityAttr var;
@@ -304,13 +395,25 @@ irdl::OperationOp createIRDLOperation(OpBuilder &builder,
304395
auto [operands, operandVariadicity] = getValues(tblgenOp.getOperands());
305396
auto [results, resultVariadicity] = getValues(tblgenOp.getResults());
306397

398+
SmallVector<Value> attributes;
399+
SmallVector<Attribute> attrNames;
400+
for (auto namedAttr : tblgenOp.getAttributes()) {
401+
if (namedAttr.attr.isOptional())
402+
continue;
403+
attributes.push_back(createAttrConstraint(consBuilder, namedAttr.attr));
404+
attrNames.push_back(StringAttr::get(ctx, namedAttr.name));
405+
}
406+
307407
// Create the operands and results operations.
308408
if (!operands.empty())
309409
consBuilder.create<irdl::OperandsOp>(UnknownLoc::get(ctx), operands,
310410
operandVariadicity);
311411
if (!results.empty())
312412
consBuilder.create<irdl::ResultsOp>(UnknownLoc::get(ctx), results,
313413
resultVariadicity);
414+
if (!attributes.empty())
415+
consBuilder.create<irdl::AttributesOp>(UnknownLoc::get(ctx), attributes,
416+
ArrayAttr::get(ctx, attrNames));
314417

315418
return op;
316419
}
@@ -328,6 +431,20 @@ irdl::TypeOp createIRDLType(OpBuilder &builder, tblgen::TypeDef &tblgenType) {
328431
return op;
329432
}
330433

434+
irdl::AttributeOp createIRDLAttr(OpBuilder &builder,
435+
tblgen::AttrDef &tblgenAttr) {
436+
MLIRContext *ctx = builder.getContext();
437+
StringRef attrName = getAttrName(tblgenAttr);
438+
std::string combined = ("#" + attrName).str();
439+
440+
irdl::AttributeOp op = builder.create<irdl::AttributeOp>(
441+
UnknownLoc::get(ctx), StringAttr::get(ctx, combined));
442+
443+
op.getBody().emplaceBlock();
444+
445+
return op;
446+
}
447+
331448
static irdl::DialectOp createIRDLDialect(OpBuilder &builder) {
332449
MLIRContext *ctx = builder.getContext();
333450
return builder.create<irdl::DialectOp>(UnknownLoc::get(ctx),
@@ -358,6 +475,14 @@ static bool emitDialectIRDLDefs(const RecordKeeper &recordKeeper,
358475
createIRDLType(builder, tblgenType);
359476
}
360477

478+
for (const Record *attr :
479+
recordKeeper.getAllDerivedDefinitionsIfDefined("AttrDef")) {
480+
tblgen::AttrDef tblgenAttr(attr);
481+
if (tblgenAttr.getDialect().getName() != selectedDialect)
482+
continue;
483+
createIRDLAttr(builder, tblgenAttr);
484+
}
485+
361486
for (const Record *def :
362487
recordKeeper.getAllDerivedDefinitionsIfDefined("Op")) {
363488
tblgen::Operator tblgenOp(def);

0 commit comments

Comments
 (0)