-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir] [tblgen-to-irdl] Add attributes to tblgen-to-irdl script #109633
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir-ods Author: Alex Rice (alexarice) ChangesAdds the ability to export attributes from the dialect and attributes of operations in the dialect @math-fehr Full diff: https://github.com/llvm/llvm-project/pull/109633.diff 3 Files Affected:
diff --git a/mlir/include/mlir/IR/CommonAttrConstraints.td b/mlir/include/mlir/IR/CommonAttrConstraints.td
index 853fb318c76e71..de5f6797235e3c 100644
--- a/mlir/include/mlir/IR/CommonAttrConstraints.td
+++ b/mlir/include/mlir/IR/CommonAttrConstraints.td
@@ -178,6 +178,7 @@ class AnyAttrOf<list<Attr> allowedAttrs, string summary = "",
summary)> {
let returnType = cppType;
let convertFromStorage = fromStorage;
+ list<Attr> allowedAttributes = allowedAttrs;
}
def LocationAttr : Attr<CPred<"::llvm::isa<::mlir::LocationAttr>($_self)">,
@@ -743,6 +744,8 @@ class ConfinedAttr<Attr attr, list<AttrConstraint> constraints> : Attr<
let isOptional = attr.isOptional;
let baseAttr = attr;
+
+ list<AttrConstraint> attrConstraints = constraints;
}
// An AttrConstraint that holds if all attr constraints specified in
diff --git a/mlir/test/tblgen-to-irdl/TestDialect.td b/mlir/test/tblgen-to-irdl/TestDialect.td
index 4fea3d8576e9ab..1ba84a5d3683d4 100644
--- a/mlir/test/tblgen-to-irdl/TestDialect.td
+++ b/mlir/test/tblgen-to-irdl/TestDialect.td
@@ -13,6 +13,10 @@ class Test_Type<string name, string typeMnemonic, list<Trait> traits = []>
let mnemonic = typeMnemonic;
}
+class Test_Attr<string name, string attrMnemonic> : AttrDef<Test_Dialect, name> {
+ let mnemonic = attrMnemonic;
+}
+
class Test_Op<string mnemonic, list<Trait> traits = []>
: Op<Test_Dialect, mnemonic, traits>;
@@ -22,6 +26,8 @@ def Test_SingletonAType : Test_Type<"SingletonAType", "singleton_a"> {}
def Test_SingletonBType : Test_Type<"SingletonBType", "singleton_b"> {}
// CHECK: irdl.type @"!singleton_c"
def Test_SingletonCType : Test_Type<"SingletonCType", "singleton_c"> {}
+// CHECK: irdl.attribute @"#test"
+def Test_TestAttr : Test_Attr<"Test", "test"> {}
// Check that AllOfType is converted correctly.
@@ -45,6 +51,17 @@ def Test_AnyOp : Test_Op<"any"> {
// CHECK-NEXT: irdl.operands(%[[v0]])
// CHECK-NEXT: }
+// Check attributes are converted correctly.
+def Test_AttributesOp : Test_Op<"attributes"> {
+ let arguments = (ins I16Attr:$int_attr,
+ Test_TestAttr:$test_attr);
+}
+// CHECK-LABEL: irdl.operation @attributes {
+// CHECK-NEXT: %[[v0:[^ ]*]] = irdl.base "!builtin.integer"
+// CHECK-NEXT: %[[v1:[^ ]*]] = irdl.base @test::@"#test"
+// CHECK-NEXT: irdl.attributes {"int_attr" = %[[v0]], "test_attr" = %[[v1]]}
+// CHECK-NEXT: }
+
// Check confined types are converted correctly.
def Test_ConfinedOp : Test_Op<"confined"> {
let arguments = (ins ConfinedType<AnyType, [CPred<"::llvm::isa<::mlir::TensorType>($_self)">]>:$tensor,
diff --git a/mlir/tools/tblgen-to-irdl/OpDefinitionsGen.cpp b/mlir/tools/tblgen-to-irdl/OpDefinitionsGen.cpp
index 45957bafc378e3..d0a3552fb123da 100644
--- a/mlir/tools/tblgen-to-irdl/OpDefinitionsGen.cpp
+++ b/mlir/tools/tblgen-to-irdl/OpDefinitionsGen.cpp
@@ -74,8 +74,14 @@ Value typeToConstraint(OpBuilder &builder, Type type) {
return op.getOutput();
}
-std::optional<Type> recordToType(MLIRContext *ctx, const Record &predRec) {
+Value baseToConstraint(OpBuilder &builder, StringRef baseClass) {
+ MLIRContext *ctx = builder.getContext();
+ auto op = builder.create<irdl::BaseOp>(UnknownLoc::get(ctx),
+ StringAttr::get(ctx, baseClass));
+ return op.getOutput();
+}
+std::optional<Type> recordToType(MLIRContext *ctx, const Record &predRec) {
if (predRec.isSubClassOf("I")) {
auto width = predRec.getValueAsInt("bitwidth");
return IntegerType::get(ctx, width, IntegerType::Signless);
@@ -164,12 +170,12 @@ std::optional<Type> recordToType(MLIRContext *ctx, const Record &predRec) {
return std::nullopt;
}
-Value createConstraint(OpBuilder &builder, tblgen::Constraint constraint) {
+Value createTypeConstraint(OpBuilder &builder, tblgen::Constraint constraint) {
MLIRContext *ctx = builder.getContext();
const Record &predRec = constraint.getDef();
if (predRec.isSubClassOf("Variadic") || predRec.isSubClassOf("Optional"))
- return createConstraint(builder, predRec.getValueAsDef("baseType"));
+ return createTypeConstraint(builder, predRec.getValueAsDef("baseType"));
if (predRec.getName() == "AnyType") {
auto op = builder.create<irdl::AnyOp>(UnknownLoc::get(ctx));
@@ -196,7 +202,7 @@ Value createConstraint(OpBuilder &builder, tblgen::Constraint constraint) {
std::vector<Value> constraints;
for (const Record *child : predRec.getValueAsListOfDefs("allowedTypes")) {
constraints.push_back(
- createConstraint(builder, tblgen::Constraint(child)));
+ createTypeConstraint(builder, tblgen::Constraint(child)));
}
auto op = builder.create<irdl::AnyOfOp>(UnknownLoc::get(ctx), constraints);
return op.getOutput();
@@ -206,7 +212,7 @@ Value createConstraint(OpBuilder &builder, tblgen::Constraint constraint) {
std::vector<Value> constraints;
for (const Record *child : predRec.getValueAsListOfDefs("allowedTypes")) {
constraints.push_back(
- createConstraint(builder, tblgen::Constraint(child)));
+ createTypeConstraint(builder, tblgen::Constraint(child)));
}
auto op = builder.create<irdl::AllOfOp>(UnknownLoc::get(ctx), constraints);
return op.getOutput();
@@ -241,7 +247,7 @@ Value createConstraint(OpBuilder &builder, tblgen::Constraint constraint) {
// Confined type
if (predRec.isSubClassOf("ConfinedType")) {
std::vector<Value> constraints;
- constraints.push_back(createConstraint(
+ constraints.push_back(createTypeConstraint(
builder, tblgen::Constraint(predRec.getValueAsDef("baseType"))));
for (Record *child : predRec.getValueAsListOfDefs("predicateList")) {
constraints.push_back(createPredicate(builder, tblgen::Pred(child)));
@@ -253,6 +259,85 @@ Value createConstraint(OpBuilder &builder, tblgen::Constraint constraint) {
return createPredicate(builder, constraint.getPredicate());
}
+Value createAttrConstraint(OpBuilder &builder, tblgen::Constraint constraint) {
+ MLIRContext *ctx = builder.getContext();
+ const Record &predRec = constraint.getDef();
+
+ if (predRec.isSubClassOf("DefaultValuedAttr") ||
+ predRec.isSubClassOf("DefaultValuedOptionalAttr") ||
+ predRec.isSubClassOf("OptionalAttr")) {
+ return createAttrConstraint(builder, predRec.getValueAsDef("baseAttr"));
+ }
+
+ if (predRec.isSubClassOf("ConfinedAttr")) {
+ std::vector<Value> constraints;
+ constraints.push_back(createAttrConstraint(
+ builder, tblgen::Constraint(predRec.getValueAsDef("baseAttr"))));
+ for (Record *child : predRec.getValueAsListOfDefs("attrConstraints")) {
+ constraints.push_back(createPredicate(
+ builder, tblgen::Pred(child->getValueAsDef("predicate"))));
+ }
+ auto op = builder.create<irdl::AllOfOp>(UnknownLoc::get(ctx), constraints);
+ return op.getOutput();
+ }
+
+ if (predRec.isSubClassOf("AnyAttrOf")) {
+ std::vector<Value> constraints;
+ for (Record *child : predRec.getValueAsListOfDefs("allowedAttributes")) {
+ constraints.push_back(
+ createAttrConstraint(builder, tblgen::Constraint(child)));
+ }
+ auto op = builder.create<irdl::AnyOfOp>(UnknownLoc::get(ctx), constraints);
+ return op.getOutput();
+ }
+
+ if (predRec.getName() == "AnyAttr") {
+ auto op = builder.create<irdl::AnyOp>(UnknownLoc::get(ctx));
+ return op.getOutput();
+ }
+
+ if (predRec.isSubClassOf("AnyIntegerAttrBase") ||
+ predRec.isSubClassOf("SignlessIntegerAttrBase") ||
+ predRec.isSubClassOf("SignedIntegerAttrBase") ||
+ predRec.isSubClassOf("UnsignedIntegerAttrBase") ||
+ predRec.isSubClassOf("BoolAttr")) {
+ return baseToConstraint(builder, "!builtin.integer");
+ }
+
+ if (predRec.isSubClassOf("FloatAttrBase")) {
+ return baseToConstraint(builder, "!builtin.float");
+ }
+
+ if (predRec.isSubClassOf("StringBasedAttr")) {
+ return baseToConstraint(builder, "!builtin.string");
+ }
+
+ if (predRec.getName() == "UnitAttr") {
+ auto op =
+ builder.create<irdl::IsOp>(UnknownLoc::get(ctx), UnitAttr::get(ctx));
+ return op.getOutput();
+ }
+
+ if (predRec.isSubClassOf("AttrDef")) {
+ auto dialect = predRec.getValueAsDef("dialect")->getValueAsString("name");
+ if (dialect == selectedDialect) {
+ std::string combined = ("#" + predRec.getValueAsString("mnemonic")).str();
+ SmallVector<FlatSymbolRefAttr> nested = {SymbolRefAttr::get(ctx, combined)
+
+ };
+ auto typeSymbol = SymbolRefAttr::get(ctx, dialect, nested);
+ auto op = builder.create<irdl::BaseOp>(UnknownLoc::get(ctx), typeSymbol);
+ return op.getOutput();
+ }
+ std::string typeName = ("#" + predRec.getValueAsString("attrName")).str();
+ auto op = builder.create<irdl::BaseOp>(UnknownLoc::get(ctx),
+ StringAttr::get(ctx, typeName));
+ return op.getOutput();
+ }
+
+ return createPredicate(builder, constraint.getPredicate());
+}
+
/// Returns the name of the operation without the dialect prefix.
static StringRef getOperatorName(tblgen::Operator &tblgenOp) {
StringRef opName = tblgenOp.getDef().getValueAsString("opName");
@@ -265,6 +350,12 @@ static StringRef getTypeName(tblgen::TypeDef &tblgenType) {
return opName;
}
+/// Returns the name of the attr without the dialect prefix.
+static StringRef getAttrName(tblgen::AttrDef &tblgenType) {
+ StringRef opName = tblgenType.getDef()->getValueAsString("mnemonic");
+ return opName;
+}
+
/// Extract an operation to IRDL.
irdl::OperationOp createIRDLOperation(OpBuilder &builder,
tblgen::Operator &tblgenOp) {
@@ -282,7 +373,7 @@ irdl::OperationOp createIRDLOperation(OpBuilder &builder,
SmallVector<Value> operands;
SmallVector<irdl::VariadicityAttr> variadicity;
for (const NamedTypeConstraint &namedCons : namedCons) {
- auto operand = createConstraint(consBuilder, namedCons.constraint);
+ auto operand = createTypeConstraint(consBuilder, namedCons.constraint);
operands.push_back(operand);
irdl::VariadicityAttr var;
@@ -304,6 +395,15 @@ irdl::OperationOp createIRDLOperation(OpBuilder &builder,
auto [operands, operandVariadicity] = getValues(tblgenOp.getOperands());
auto [results, resultVariadicity] = getValues(tblgenOp.getResults());
+ SmallVector<Value> attributes;
+ SmallVector<Attribute> attrNames;
+ for (auto namedAttr : tblgenOp.getAttributes()) {
+ if (namedAttr.attr.isOptional())
+ continue;
+ attributes.push_back(createAttrConstraint(consBuilder, namedAttr.attr));
+ attrNames.push_back(StringAttr::get(ctx, namedAttr.name));
+ }
+
// Create the operands and results operations.
if (!operands.empty())
consBuilder.create<irdl::OperandsOp>(UnknownLoc::get(ctx), operands,
@@ -311,6 +411,9 @@ irdl::OperationOp createIRDLOperation(OpBuilder &builder,
if (!results.empty())
consBuilder.create<irdl::ResultsOp>(UnknownLoc::get(ctx), results,
resultVariadicity);
+ if (!attributes.empty())
+ consBuilder.create<irdl::AttributesOp>(UnknownLoc::get(ctx), attributes,
+ ArrayAttr::get(ctx, attrNames));
return op;
}
@@ -328,6 +431,20 @@ irdl::TypeOp createIRDLType(OpBuilder &builder, tblgen::TypeDef &tblgenType) {
return op;
}
+irdl::AttributeOp createIRDLAttr(OpBuilder &builder,
+ tblgen::AttrDef &tblgenAttr) {
+ MLIRContext *ctx = builder.getContext();
+ StringRef attrName = getAttrName(tblgenAttr);
+ std::string combined = ("#" + attrName).str();
+
+ irdl::AttributeOp op = builder.create<irdl::AttributeOp>(
+ UnknownLoc::get(ctx), StringAttr::get(ctx, combined));
+
+ op.getBody().emplaceBlock();
+
+ return op;
+}
+
static irdl::DialectOp createIRDLDialect(OpBuilder &builder) {
MLIRContext *ctx = builder.getContext();
return builder.create<irdl::DialectOp>(UnknownLoc::get(ctx),
@@ -358,6 +475,14 @@ static bool emitDialectIRDLDefs(const RecordKeeper &recordKeeper,
createIRDLType(builder, tblgenType);
}
+ for (const Record *attr :
+ recordKeeper.getAllDerivedDefinitionsIfDefined("AttrDef")) {
+ tblgen::AttrDef tblgenAttr(attr);
+ if (tblgenAttr.getDialect().getName() != selectedDialect)
+ continue;
+ createIRDLAttr(builder, tblgenAttr);
+ }
+
for (const Record *def :
recordKeeper.getAllDerivedDefinitionsIfDefined("Op")) {
tblgen::Operator tblgenOp(def);
|
@llvm/pr-subscribers-mlir-core Author: Alex Rice (alexarice) ChangesAdds the ability to export attributes from the dialect and attributes of operations in the dialect @math-fehr Full diff: https://github.com/llvm/llvm-project/pull/109633.diff 3 Files Affected:
diff --git a/mlir/include/mlir/IR/CommonAttrConstraints.td b/mlir/include/mlir/IR/CommonAttrConstraints.td
index 853fb318c76e71..de5f6797235e3c 100644
--- a/mlir/include/mlir/IR/CommonAttrConstraints.td
+++ b/mlir/include/mlir/IR/CommonAttrConstraints.td
@@ -178,6 +178,7 @@ class AnyAttrOf<list<Attr> allowedAttrs, string summary = "",
summary)> {
let returnType = cppType;
let convertFromStorage = fromStorage;
+ list<Attr> allowedAttributes = allowedAttrs;
}
def LocationAttr : Attr<CPred<"::llvm::isa<::mlir::LocationAttr>($_self)">,
@@ -743,6 +744,8 @@ class ConfinedAttr<Attr attr, list<AttrConstraint> constraints> : Attr<
let isOptional = attr.isOptional;
let baseAttr = attr;
+
+ list<AttrConstraint> attrConstraints = constraints;
}
// An AttrConstraint that holds if all attr constraints specified in
diff --git a/mlir/test/tblgen-to-irdl/TestDialect.td b/mlir/test/tblgen-to-irdl/TestDialect.td
index 4fea3d8576e9ab..1ba84a5d3683d4 100644
--- a/mlir/test/tblgen-to-irdl/TestDialect.td
+++ b/mlir/test/tblgen-to-irdl/TestDialect.td
@@ -13,6 +13,10 @@ class Test_Type<string name, string typeMnemonic, list<Trait> traits = []>
let mnemonic = typeMnemonic;
}
+class Test_Attr<string name, string attrMnemonic> : AttrDef<Test_Dialect, name> {
+ let mnemonic = attrMnemonic;
+}
+
class Test_Op<string mnemonic, list<Trait> traits = []>
: Op<Test_Dialect, mnemonic, traits>;
@@ -22,6 +26,8 @@ def Test_SingletonAType : Test_Type<"SingletonAType", "singleton_a"> {}
def Test_SingletonBType : Test_Type<"SingletonBType", "singleton_b"> {}
// CHECK: irdl.type @"!singleton_c"
def Test_SingletonCType : Test_Type<"SingletonCType", "singleton_c"> {}
+// CHECK: irdl.attribute @"#test"
+def Test_TestAttr : Test_Attr<"Test", "test"> {}
// Check that AllOfType is converted correctly.
@@ -45,6 +51,17 @@ def Test_AnyOp : Test_Op<"any"> {
// CHECK-NEXT: irdl.operands(%[[v0]])
// CHECK-NEXT: }
+// Check attributes are converted correctly.
+def Test_AttributesOp : Test_Op<"attributes"> {
+ let arguments = (ins I16Attr:$int_attr,
+ Test_TestAttr:$test_attr);
+}
+// CHECK-LABEL: irdl.operation @attributes {
+// CHECK-NEXT: %[[v0:[^ ]*]] = irdl.base "!builtin.integer"
+// CHECK-NEXT: %[[v1:[^ ]*]] = irdl.base @test::@"#test"
+// CHECK-NEXT: irdl.attributes {"int_attr" = %[[v0]], "test_attr" = %[[v1]]}
+// CHECK-NEXT: }
+
// Check confined types are converted correctly.
def Test_ConfinedOp : Test_Op<"confined"> {
let arguments = (ins ConfinedType<AnyType, [CPred<"::llvm::isa<::mlir::TensorType>($_self)">]>:$tensor,
diff --git a/mlir/tools/tblgen-to-irdl/OpDefinitionsGen.cpp b/mlir/tools/tblgen-to-irdl/OpDefinitionsGen.cpp
index 45957bafc378e3..d0a3552fb123da 100644
--- a/mlir/tools/tblgen-to-irdl/OpDefinitionsGen.cpp
+++ b/mlir/tools/tblgen-to-irdl/OpDefinitionsGen.cpp
@@ -74,8 +74,14 @@ Value typeToConstraint(OpBuilder &builder, Type type) {
return op.getOutput();
}
-std::optional<Type> recordToType(MLIRContext *ctx, const Record &predRec) {
+Value baseToConstraint(OpBuilder &builder, StringRef baseClass) {
+ MLIRContext *ctx = builder.getContext();
+ auto op = builder.create<irdl::BaseOp>(UnknownLoc::get(ctx),
+ StringAttr::get(ctx, baseClass));
+ return op.getOutput();
+}
+std::optional<Type> recordToType(MLIRContext *ctx, const Record &predRec) {
if (predRec.isSubClassOf("I")) {
auto width = predRec.getValueAsInt("bitwidth");
return IntegerType::get(ctx, width, IntegerType::Signless);
@@ -164,12 +170,12 @@ std::optional<Type> recordToType(MLIRContext *ctx, const Record &predRec) {
return std::nullopt;
}
-Value createConstraint(OpBuilder &builder, tblgen::Constraint constraint) {
+Value createTypeConstraint(OpBuilder &builder, tblgen::Constraint constraint) {
MLIRContext *ctx = builder.getContext();
const Record &predRec = constraint.getDef();
if (predRec.isSubClassOf("Variadic") || predRec.isSubClassOf("Optional"))
- return createConstraint(builder, predRec.getValueAsDef("baseType"));
+ return createTypeConstraint(builder, predRec.getValueAsDef("baseType"));
if (predRec.getName() == "AnyType") {
auto op = builder.create<irdl::AnyOp>(UnknownLoc::get(ctx));
@@ -196,7 +202,7 @@ Value createConstraint(OpBuilder &builder, tblgen::Constraint constraint) {
std::vector<Value> constraints;
for (const Record *child : predRec.getValueAsListOfDefs("allowedTypes")) {
constraints.push_back(
- createConstraint(builder, tblgen::Constraint(child)));
+ createTypeConstraint(builder, tblgen::Constraint(child)));
}
auto op = builder.create<irdl::AnyOfOp>(UnknownLoc::get(ctx), constraints);
return op.getOutput();
@@ -206,7 +212,7 @@ Value createConstraint(OpBuilder &builder, tblgen::Constraint constraint) {
std::vector<Value> constraints;
for (const Record *child : predRec.getValueAsListOfDefs("allowedTypes")) {
constraints.push_back(
- createConstraint(builder, tblgen::Constraint(child)));
+ createTypeConstraint(builder, tblgen::Constraint(child)));
}
auto op = builder.create<irdl::AllOfOp>(UnknownLoc::get(ctx), constraints);
return op.getOutput();
@@ -241,7 +247,7 @@ Value createConstraint(OpBuilder &builder, tblgen::Constraint constraint) {
// Confined type
if (predRec.isSubClassOf("ConfinedType")) {
std::vector<Value> constraints;
- constraints.push_back(createConstraint(
+ constraints.push_back(createTypeConstraint(
builder, tblgen::Constraint(predRec.getValueAsDef("baseType"))));
for (Record *child : predRec.getValueAsListOfDefs("predicateList")) {
constraints.push_back(createPredicate(builder, tblgen::Pred(child)));
@@ -253,6 +259,85 @@ Value createConstraint(OpBuilder &builder, tblgen::Constraint constraint) {
return createPredicate(builder, constraint.getPredicate());
}
+Value createAttrConstraint(OpBuilder &builder, tblgen::Constraint constraint) {
+ MLIRContext *ctx = builder.getContext();
+ const Record &predRec = constraint.getDef();
+
+ if (predRec.isSubClassOf("DefaultValuedAttr") ||
+ predRec.isSubClassOf("DefaultValuedOptionalAttr") ||
+ predRec.isSubClassOf("OptionalAttr")) {
+ return createAttrConstraint(builder, predRec.getValueAsDef("baseAttr"));
+ }
+
+ if (predRec.isSubClassOf("ConfinedAttr")) {
+ std::vector<Value> constraints;
+ constraints.push_back(createAttrConstraint(
+ builder, tblgen::Constraint(predRec.getValueAsDef("baseAttr"))));
+ for (Record *child : predRec.getValueAsListOfDefs("attrConstraints")) {
+ constraints.push_back(createPredicate(
+ builder, tblgen::Pred(child->getValueAsDef("predicate"))));
+ }
+ auto op = builder.create<irdl::AllOfOp>(UnknownLoc::get(ctx), constraints);
+ return op.getOutput();
+ }
+
+ if (predRec.isSubClassOf("AnyAttrOf")) {
+ std::vector<Value> constraints;
+ for (Record *child : predRec.getValueAsListOfDefs("allowedAttributes")) {
+ constraints.push_back(
+ createAttrConstraint(builder, tblgen::Constraint(child)));
+ }
+ auto op = builder.create<irdl::AnyOfOp>(UnknownLoc::get(ctx), constraints);
+ return op.getOutput();
+ }
+
+ if (predRec.getName() == "AnyAttr") {
+ auto op = builder.create<irdl::AnyOp>(UnknownLoc::get(ctx));
+ return op.getOutput();
+ }
+
+ if (predRec.isSubClassOf("AnyIntegerAttrBase") ||
+ predRec.isSubClassOf("SignlessIntegerAttrBase") ||
+ predRec.isSubClassOf("SignedIntegerAttrBase") ||
+ predRec.isSubClassOf("UnsignedIntegerAttrBase") ||
+ predRec.isSubClassOf("BoolAttr")) {
+ return baseToConstraint(builder, "!builtin.integer");
+ }
+
+ if (predRec.isSubClassOf("FloatAttrBase")) {
+ return baseToConstraint(builder, "!builtin.float");
+ }
+
+ if (predRec.isSubClassOf("StringBasedAttr")) {
+ return baseToConstraint(builder, "!builtin.string");
+ }
+
+ if (predRec.getName() == "UnitAttr") {
+ auto op =
+ builder.create<irdl::IsOp>(UnknownLoc::get(ctx), UnitAttr::get(ctx));
+ return op.getOutput();
+ }
+
+ if (predRec.isSubClassOf("AttrDef")) {
+ auto dialect = predRec.getValueAsDef("dialect")->getValueAsString("name");
+ if (dialect == selectedDialect) {
+ std::string combined = ("#" + predRec.getValueAsString("mnemonic")).str();
+ SmallVector<FlatSymbolRefAttr> nested = {SymbolRefAttr::get(ctx, combined)
+
+ };
+ auto typeSymbol = SymbolRefAttr::get(ctx, dialect, nested);
+ auto op = builder.create<irdl::BaseOp>(UnknownLoc::get(ctx), typeSymbol);
+ return op.getOutput();
+ }
+ std::string typeName = ("#" + predRec.getValueAsString("attrName")).str();
+ auto op = builder.create<irdl::BaseOp>(UnknownLoc::get(ctx),
+ StringAttr::get(ctx, typeName));
+ return op.getOutput();
+ }
+
+ return createPredicate(builder, constraint.getPredicate());
+}
+
/// Returns the name of the operation without the dialect prefix.
static StringRef getOperatorName(tblgen::Operator &tblgenOp) {
StringRef opName = tblgenOp.getDef().getValueAsString("opName");
@@ -265,6 +350,12 @@ static StringRef getTypeName(tblgen::TypeDef &tblgenType) {
return opName;
}
+/// Returns the name of the attr without the dialect prefix.
+static StringRef getAttrName(tblgen::AttrDef &tblgenType) {
+ StringRef opName = tblgenType.getDef()->getValueAsString("mnemonic");
+ return opName;
+}
+
/// Extract an operation to IRDL.
irdl::OperationOp createIRDLOperation(OpBuilder &builder,
tblgen::Operator &tblgenOp) {
@@ -282,7 +373,7 @@ irdl::OperationOp createIRDLOperation(OpBuilder &builder,
SmallVector<Value> operands;
SmallVector<irdl::VariadicityAttr> variadicity;
for (const NamedTypeConstraint &namedCons : namedCons) {
- auto operand = createConstraint(consBuilder, namedCons.constraint);
+ auto operand = createTypeConstraint(consBuilder, namedCons.constraint);
operands.push_back(operand);
irdl::VariadicityAttr var;
@@ -304,6 +395,15 @@ irdl::OperationOp createIRDLOperation(OpBuilder &builder,
auto [operands, operandVariadicity] = getValues(tblgenOp.getOperands());
auto [results, resultVariadicity] = getValues(tblgenOp.getResults());
+ SmallVector<Value> attributes;
+ SmallVector<Attribute> attrNames;
+ for (auto namedAttr : tblgenOp.getAttributes()) {
+ if (namedAttr.attr.isOptional())
+ continue;
+ attributes.push_back(createAttrConstraint(consBuilder, namedAttr.attr));
+ attrNames.push_back(StringAttr::get(ctx, namedAttr.name));
+ }
+
// Create the operands and results operations.
if (!operands.empty())
consBuilder.create<irdl::OperandsOp>(UnknownLoc::get(ctx), operands,
@@ -311,6 +411,9 @@ irdl::OperationOp createIRDLOperation(OpBuilder &builder,
if (!results.empty())
consBuilder.create<irdl::ResultsOp>(UnknownLoc::get(ctx), results,
resultVariadicity);
+ if (!attributes.empty())
+ consBuilder.create<irdl::AttributesOp>(UnknownLoc::get(ctx), attributes,
+ ArrayAttr::get(ctx, attrNames));
return op;
}
@@ -328,6 +431,20 @@ irdl::TypeOp createIRDLType(OpBuilder &builder, tblgen::TypeDef &tblgenType) {
return op;
}
+irdl::AttributeOp createIRDLAttr(OpBuilder &builder,
+ tblgen::AttrDef &tblgenAttr) {
+ MLIRContext *ctx = builder.getContext();
+ StringRef attrName = getAttrName(tblgenAttr);
+ std::string combined = ("#" + attrName).str();
+
+ irdl::AttributeOp op = builder.create<irdl::AttributeOp>(
+ UnknownLoc::get(ctx), StringAttr::get(ctx, combined));
+
+ op.getBody().emplaceBlock();
+
+ return op;
+}
+
static irdl::DialectOp createIRDLDialect(OpBuilder &builder) {
MLIRContext *ctx = builder.getContext();
return builder.create<irdl::DialectOp>(UnknownLoc::get(ctx),
@@ -358,6 +475,14 @@ static bool emitDialectIRDLDefs(const RecordKeeper &recordKeeper,
createIRDLType(builder, tblgenType);
}
+ for (const Record *attr :
+ recordKeeper.getAllDerivedDefinitionsIfDefined("AttrDef")) {
+ tblgen::AttrDef tblgenAttr(attr);
+ if (tblgenAttr.getDialect().getName() != selectedDialect)
+ continue;
+ createIRDLAttr(builder, tblgenAttr);
+ }
+
for (const Record *def :
recordKeeper.getAllDerivedDefinitionsIfDefined("Op")) {
tblgen::Operator tblgenOp(def);
|
@llvm/pr-subscribers-mlir Author: Alex Rice (alexarice) ChangesAdds the ability to export attributes from the dialect and attributes of operations in the dialect @math-fehr Full diff: https://github.com/llvm/llvm-project/pull/109633.diff 3 Files Affected:
diff --git a/mlir/include/mlir/IR/CommonAttrConstraints.td b/mlir/include/mlir/IR/CommonAttrConstraints.td
index 853fb318c76e71..de5f6797235e3c 100644
--- a/mlir/include/mlir/IR/CommonAttrConstraints.td
+++ b/mlir/include/mlir/IR/CommonAttrConstraints.td
@@ -178,6 +178,7 @@ class AnyAttrOf<list<Attr> allowedAttrs, string summary = "",
summary)> {
let returnType = cppType;
let convertFromStorage = fromStorage;
+ list<Attr> allowedAttributes = allowedAttrs;
}
def LocationAttr : Attr<CPred<"::llvm::isa<::mlir::LocationAttr>($_self)">,
@@ -743,6 +744,8 @@ class ConfinedAttr<Attr attr, list<AttrConstraint> constraints> : Attr<
let isOptional = attr.isOptional;
let baseAttr = attr;
+
+ list<AttrConstraint> attrConstraints = constraints;
}
// An AttrConstraint that holds if all attr constraints specified in
diff --git a/mlir/test/tblgen-to-irdl/TestDialect.td b/mlir/test/tblgen-to-irdl/TestDialect.td
index 4fea3d8576e9ab..1ba84a5d3683d4 100644
--- a/mlir/test/tblgen-to-irdl/TestDialect.td
+++ b/mlir/test/tblgen-to-irdl/TestDialect.td
@@ -13,6 +13,10 @@ class Test_Type<string name, string typeMnemonic, list<Trait> traits = []>
let mnemonic = typeMnemonic;
}
+class Test_Attr<string name, string attrMnemonic> : AttrDef<Test_Dialect, name> {
+ let mnemonic = attrMnemonic;
+}
+
class Test_Op<string mnemonic, list<Trait> traits = []>
: Op<Test_Dialect, mnemonic, traits>;
@@ -22,6 +26,8 @@ def Test_SingletonAType : Test_Type<"SingletonAType", "singleton_a"> {}
def Test_SingletonBType : Test_Type<"SingletonBType", "singleton_b"> {}
// CHECK: irdl.type @"!singleton_c"
def Test_SingletonCType : Test_Type<"SingletonCType", "singleton_c"> {}
+// CHECK: irdl.attribute @"#test"
+def Test_TestAttr : Test_Attr<"Test", "test"> {}
// Check that AllOfType is converted correctly.
@@ -45,6 +51,17 @@ def Test_AnyOp : Test_Op<"any"> {
// CHECK-NEXT: irdl.operands(%[[v0]])
// CHECK-NEXT: }
+// Check attributes are converted correctly.
+def Test_AttributesOp : Test_Op<"attributes"> {
+ let arguments = (ins I16Attr:$int_attr,
+ Test_TestAttr:$test_attr);
+}
+// CHECK-LABEL: irdl.operation @attributes {
+// CHECK-NEXT: %[[v0:[^ ]*]] = irdl.base "!builtin.integer"
+// CHECK-NEXT: %[[v1:[^ ]*]] = irdl.base @test::@"#test"
+// CHECK-NEXT: irdl.attributes {"int_attr" = %[[v0]], "test_attr" = %[[v1]]}
+// CHECK-NEXT: }
+
// Check confined types are converted correctly.
def Test_ConfinedOp : Test_Op<"confined"> {
let arguments = (ins ConfinedType<AnyType, [CPred<"::llvm::isa<::mlir::TensorType>($_self)">]>:$tensor,
diff --git a/mlir/tools/tblgen-to-irdl/OpDefinitionsGen.cpp b/mlir/tools/tblgen-to-irdl/OpDefinitionsGen.cpp
index 45957bafc378e3..d0a3552fb123da 100644
--- a/mlir/tools/tblgen-to-irdl/OpDefinitionsGen.cpp
+++ b/mlir/tools/tblgen-to-irdl/OpDefinitionsGen.cpp
@@ -74,8 +74,14 @@ Value typeToConstraint(OpBuilder &builder, Type type) {
return op.getOutput();
}
-std::optional<Type> recordToType(MLIRContext *ctx, const Record &predRec) {
+Value baseToConstraint(OpBuilder &builder, StringRef baseClass) {
+ MLIRContext *ctx = builder.getContext();
+ auto op = builder.create<irdl::BaseOp>(UnknownLoc::get(ctx),
+ StringAttr::get(ctx, baseClass));
+ return op.getOutput();
+}
+std::optional<Type> recordToType(MLIRContext *ctx, const Record &predRec) {
if (predRec.isSubClassOf("I")) {
auto width = predRec.getValueAsInt("bitwidth");
return IntegerType::get(ctx, width, IntegerType::Signless);
@@ -164,12 +170,12 @@ std::optional<Type> recordToType(MLIRContext *ctx, const Record &predRec) {
return std::nullopt;
}
-Value createConstraint(OpBuilder &builder, tblgen::Constraint constraint) {
+Value createTypeConstraint(OpBuilder &builder, tblgen::Constraint constraint) {
MLIRContext *ctx = builder.getContext();
const Record &predRec = constraint.getDef();
if (predRec.isSubClassOf("Variadic") || predRec.isSubClassOf("Optional"))
- return createConstraint(builder, predRec.getValueAsDef("baseType"));
+ return createTypeConstraint(builder, predRec.getValueAsDef("baseType"));
if (predRec.getName() == "AnyType") {
auto op = builder.create<irdl::AnyOp>(UnknownLoc::get(ctx));
@@ -196,7 +202,7 @@ Value createConstraint(OpBuilder &builder, tblgen::Constraint constraint) {
std::vector<Value> constraints;
for (const Record *child : predRec.getValueAsListOfDefs("allowedTypes")) {
constraints.push_back(
- createConstraint(builder, tblgen::Constraint(child)));
+ createTypeConstraint(builder, tblgen::Constraint(child)));
}
auto op = builder.create<irdl::AnyOfOp>(UnknownLoc::get(ctx), constraints);
return op.getOutput();
@@ -206,7 +212,7 @@ Value createConstraint(OpBuilder &builder, tblgen::Constraint constraint) {
std::vector<Value> constraints;
for (const Record *child : predRec.getValueAsListOfDefs("allowedTypes")) {
constraints.push_back(
- createConstraint(builder, tblgen::Constraint(child)));
+ createTypeConstraint(builder, tblgen::Constraint(child)));
}
auto op = builder.create<irdl::AllOfOp>(UnknownLoc::get(ctx), constraints);
return op.getOutput();
@@ -241,7 +247,7 @@ Value createConstraint(OpBuilder &builder, tblgen::Constraint constraint) {
// Confined type
if (predRec.isSubClassOf("ConfinedType")) {
std::vector<Value> constraints;
- constraints.push_back(createConstraint(
+ constraints.push_back(createTypeConstraint(
builder, tblgen::Constraint(predRec.getValueAsDef("baseType"))));
for (Record *child : predRec.getValueAsListOfDefs("predicateList")) {
constraints.push_back(createPredicate(builder, tblgen::Pred(child)));
@@ -253,6 +259,85 @@ Value createConstraint(OpBuilder &builder, tblgen::Constraint constraint) {
return createPredicate(builder, constraint.getPredicate());
}
+Value createAttrConstraint(OpBuilder &builder, tblgen::Constraint constraint) {
+ MLIRContext *ctx = builder.getContext();
+ const Record &predRec = constraint.getDef();
+
+ if (predRec.isSubClassOf("DefaultValuedAttr") ||
+ predRec.isSubClassOf("DefaultValuedOptionalAttr") ||
+ predRec.isSubClassOf("OptionalAttr")) {
+ return createAttrConstraint(builder, predRec.getValueAsDef("baseAttr"));
+ }
+
+ if (predRec.isSubClassOf("ConfinedAttr")) {
+ std::vector<Value> constraints;
+ constraints.push_back(createAttrConstraint(
+ builder, tblgen::Constraint(predRec.getValueAsDef("baseAttr"))));
+ for (Record *child : predRec.getValueAsListOfDefs("attrConstraints")) {
+ constraints.push_back(createPredicate(
+ builder, tblgen::Pred(child->getValueAsDef("predicate"))));
+ }
+ auto op = builder.create<irdl::AllOfOp>(UnknownLoc::get(ctx), constraints);
+ return op.getOutput();
+ }
+
+ if (predRec.isSubClassOf("AnyAttrOf")) {
+ std::vector<Value> constraints;
+ for (Record *child : predRec.getValueAsListOfDefs("allowedAttributes")) {
+ constraints.push_back(
+ createAttrConstraint(builder, tblgen::Constraint(child)));
+ }
+ auto op = builder.create<irdl::AnyOfOp>(UnknownLoc::get(ctx), constraints);
+ return op.getOutput();
+ }
+
+ if (predRec.getName() == "AnyAttr") {
+ auto op = builder.create<irdl::AnyOp>(UnknownLoc::get(ctx));
+ return op.getOutput();
+ }
+
+ if (predRec.isSubClassOf("AnyIntegerAttrBase") ||
+ predRec.isSubClassOf("SignlessIntegerAttrBase") ||
+ predRec.isSubClassOf("SignedIntegerAttrBase") ||
+ predRec.isSubClassOf("UnsignedIntegerAttrBase") ||
+ predRec.isSubClassOf("BoolAttr")) {
+ return baseToConstraint(builder, "!builtin.integer");
+ }
+
+ if (predRec.isSubClassOf("FloatAttrBase")) {
+ return baseToConstraint(builder, "!builtin.float");
+ }
+
+ if (predRec.isSubClassOf("StringBasedAttr")) {
+ return baseToConstraint(builder, "!builtin.string");
+ }
+
+ if (predRec.getName() == "UnitAttr") {
+ auto op =
+ builder.create<irdl::IsOp>(UnknownLoc::get(ctx), UnitAttr::get(ctx));
+ return op.getOutput();
+ }
+
+ if (predRec.isSubClassOf("AttrDef")) {
+ auto dialect = predRec.getValueAsDef("dialect")->getValueAsString("name");
+ if (dialect == selectedDialect) {
+ std::string combined = ("#" + predRec.getValueAsString("mnemonic")).str();
+ SmallVector<FlatSymbolRefAttr> nested = {SymbolRefAttr::get(ctx, combined)
+
+ };
+ auto typeSymbol = SymbolRefAttr::get(ctx, dialect, nested);
+ auto op = builder.create<irdl::BaseOp>(UnknownLoc::get(ctx), typeSymbol);
+ return op.getOutput();
+ }
+ std::string typeName = ("#" + predRec.getValueAsString("attrName")).str();
+ auto op = builder.create<irdl::BaseOp>(UnknownLoc::get(ctx),
+ StringAttr::get(ctx, typeName));
+ return op.getOutput();
+ }
+
+ return createPredicate(builder, constraint.getPredicate());
+}
+
/// Returns the name of the operation without the dialect prefix.
static StringRef getOperatorName(tblgen::Operator &tblgenOp) {
StringRef opName = tblgenOp.getDef().getValueAsString("opName");
@@ -265,6 +350,12 @@ static StringRef getTypeName(tblgen::TypeDef &tblgenType) {
return opName;
}
+/// Returns the name of the attr without the dialect prefix.
+static StringRef getAttrName(tblgen::AttrDef &tblgenType) {
+ StringRef opName = tblgenType.getDef()->getValueAsString("mnemonic");
+ return opName;
+}
+
/// Extract an operation to IRDL.
irdl::OperationOp createIRDLOperation(OpBuilder &builder,
tblgen::Operator &tblgenOp) {
@@ -282,7 +373,7 @@ irdl::OperationOp createIRDLOperation(OpBuilder &builder,
SmallVector<Value> operands;
SmallVector<irdl::VariadicityAttr> variadicity;
for (const NamedTypeConstraint &namedCons : namedCons) {
- auto operand = createConstraint(consBuilder, namedCons.constraint);
+ auto operand = createTypeConstraint(consBuilder, namedCons.constraint);
operands.push_back(operand);
irdl::VariadicityAttr var;
@@ -304,6 +395,15 @@ irdl::OperationOp createIRDLOperation(OpBuilder &builder,
auto [operands, operandVariadicity] = getValues(tblgenOp.getOperands());
auto [results, resultVariadicity] = getValues(tblgenOp.getResults());
+ SmallVector<Value> attributes;
+ SmallVector<Attribute> attrNames;
+ for (auto namedAttr : tblgenOp.getAttributes()) {
+ if (namedAttr.attr.isOptional())
+ continue;
+ attributes.push_back(createAttrConstraint(consBuilder, namedAttr.attr));
+ attrNames.push_back(StringAttr::get(ctx, namedAttr.name));
+ }
+
// Create the operands and results operations.
if (!operands.empty())
consBuilder.create<irdl::OperandsOp>(UnknownLoc::get(ctx), operands,
@@ -311,6 +411,9 @@ irdl::OperationOp createIRDLOperation(OpBuilder &builder,
if (!results.empty())
consBuilder.create<irdl::ResultsOp>(UnknownLoc::get(ctx), results,
resultVariadicity);
+ if (!attributes.empty())
+ consBuilder.create<irdl::AttributesOp>(UnknownLoc::get(ctx), attributes,
+ ArrayAttr::get(ctx, attrNames));
return op;
}
@@ -328,6 +431,20 @@ irdl::TypeOp createIRDLType(OpBuilder &builder, tblgen::TypeDef &tblgenType) {
return op;
}
+irdl::AttributeOp createIRDLAttr(OpBuilder &builder,
+ tblgen::AttrDef &tblgenAttr) {
+ MLIRContext *ctx = builder.getContext();
+ StringRef attrName = getAttrName(tblgenAttr);
+ std::string combined = ("#" + attrName).str();
+
+ irdl::AttributeOp op = builder.create<irdl::AttributeOp>(
+ UnknownLoc::get(ctx), StringAttr::get(ctx, combined));
+
+ op.getBody().emplaceBlock();
+
+ return op;
+}
+
static irdl::DialectOp createIRDLDialect(OpBuilder &builder) {
MLIRContext *ctx = builder.getContext();
return builder.create<irdl::DialectOp>(UnknownLoc::get(ctx),
@@ -358,6 +475,14 @@ static bool emitDialectIRDLDefs(const RecordKeeper &recordKeeper,
createIRDLType(builder, tblgenType);
}
+ for (const Record *attr :
+ recordKeeper.getAllDerivedDefinitionsIfDefined("AttrDef")) {
+ tblgen::AttrDef tblgenAttr(attr);
+ if (tblgenAttr.getDialect().getName() != selectedDialect)
+ continue;
+ createIRDLAttr(builder, tblgenAttr);
+ }
+
for (const Record *def :
recordKeeper.getAllDerivedDefinitionsIfDefined("Op")) {
tblgen::Operator tblgenOp(def);
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Very nice!
Adds the ability to export attributes from the dialect and attributes of operations in the dialect
@math-fehr