-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir] [tblgen-to-irdl] Refactor tblgen-to-irdl script and support more types #105505
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir] [tblgen-to-irdl] Refactor tblgen-to-irdl script and support more types #105505
Conversation
@llvm/pr-subscribers-mlir-core @llvm/pr-subscribers-mlir-ods Author: Alex Rice (alexarice) ChangesRefactors the tblgen-to-irdl script slightly and adds support for
Also doesn't add the operand and result ops if they are empty. I could potentially split this into smaller PRs if that'd be helpful (refactor + integer/float/complex, confined type, optional operand/result). @math-fehr Full diff: https://github.com/llvm/llvm-project/pull/105505.diff 4 Files Affected:
diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td
index 4536d781ef674f..0e076413d0d9f3 100644
--- a/mlir/include/mlir/IR/CommonTypeConstraints.td
+++ b/mlir/include/mlir/IR/CommonTypeConstraints.td
@@ -198,7 +198,10 @@ class AllOfType<list<Type> allowedTypeList, string summary = "",
class ConfinedType<Type type, list<Pred> predicates, string summary = "",
string cppType = type.cppType> : Type<
And<!listconcat([type.predicate], !foreach(pred, predicates, pred))>,
- summary, cppType>;
+ summary, cppType> {
+ Type baseType = type;
+ list<Pred> predicateList = predicates;
+}
// Integer types.
diff --git a/mlir/test/tblgen-to-irdl/CMathDialect.td b/mlir/test/tblgen-to-irdl/CMathDialect.td
index 5b9e756727cb36..454543e074c489 100644
--- a/mlir/test/tblgen-to-irdl/CMathDialect.td
+++ b/mlir/test/tblgen-to-irdl/CMathDialect.td
@@ -25,7 +25,6 @@ def CMath_ComplexType : CMath_Type<"ComplexType", "complex"> {
// CHECK: irdl.operation @identity {
// CHECK-NEXT: %0 = irdl.base "!cmath.complex"
-// CHECK-NEXT: irdl.operands()
// CHECK-NEXT: irdl.results(%0)
// CHECK-NEXT: }
def CMath_IdentityOp : CMath_Op<"identity"> {
diff --git a/mlir/test/tblgen-to-irdl/TestDialect.td b/mlir/test/tblgen-to-irdl/TestDialect.td
index fc40da527db00a..a86dcb5b3b66e2 100644
--- a/mlir/test/tblgen-to-irdl/TestDialect.td
+++ b/mlir/test/tblgen-to-irdl/TestDialect.td
@@ -28,9 +28,8 @@ def Test_AndOp : Test_Op<"and"> {
// CHECK-LABEL: irdl.operation @and {
// CHECK-NEXT: %[[v0:[^ ]*]] = irdl.base "!test.singleton_a"
// CHECK-NEXT: %[[v1:[^ ]*]] = irdl.any
-// CHECK-NEXT: %[[v2:[^ ]*]] = irdl.all_of(%[[v0]], %[[v1]])
+// CHECK-NEXT: %[[v2:[^ ]*]] = irdl.all_of(%[[v0]], %[[v1]])
// CHECK-NEXT: irdl.operands(%[[v2]])
-// CHECK-NEXT: irdl.results()
// CHECK-NEXT: }
@@ -41,9 +40,37 @@ def Test_AnyOp : Test_Op<"any"> {
// CHECK-LABEL: irdl.operation @any {
// CHECK-NEXT: %[[v0:[^ ]*]] = irdl.any
// CHECK-NEXT: irdl.operands(%[[v0]])
-// CHECK-NEXT: irdl.results()
// CHECK-NEXT: }
+// Check confined types are converted correctly.
+def Test_ConfinedOp : Test_Op<"confined"> {
+ let arguments = (ins ConfinedType<I32, [IntNonNegative.predicate]>:$confined,
+ ConfinedType<I8, [And<[IntMinValue<1>.predicate, IntMaxValue<2>.predicate]>]>:$bounded);
+}
+// CHECK-LABEL: irdl.operation @confined {
+// CHECK-NEXT: %[[v0:[^ ]*]] = irdl.is i32
+// CHECK-NEXT: %[[v1:[^ ]*]] = irdl.c_pred "{{.*}}"
+// CHECK-NEXT: %[[v2:[^ ]*]] = irdl.all_of(%[[v0]], %[[v1]])
+// CHECK-NEXT: %[[v3:[^ ]*]] = irdl.is i8
+// CHECK-NEXT: %[[v4:[^ ]*]] = irdl.c_pred "{{.*}}"
+// CHECK-NEXT: %[[v5:[^ ]*]] = irdl.c_pred "{{.*}}"
+// CHECK-NEXT: %[[v6:[^ ]*]] = irdl.all_of(%[[v4]], %[[v5]])
+// CHECK-NEXT: %[[v7:[^ ]*]] = irdl.all_of(%[[v3]], %[[v6]])
+// CHECK-NEXT: irdl.operands(%[[v2]], %[[v7]])
+// CHECK-NEXT: }
+
+def Test_Integers : Test_Op<"integers"> {
+ let arguments = (ins AnyI8:$any_int,
+ AnyInteger:$any_integer);
+}
+// CHECK-LABEL: irdl.operation @integers {
+// CHECK-NEXT: %[[v0:[^ ]*]] = irdl.is i8
+// CHECK-NEXT: %[[v1:[^ ]*]] = irdl.is si8
+// CHECK-NEXT: %[[v2:[^ ]*]] = irdl.is ui8
+// CHECK-NEXT: %[[v3:[^ ]*]] = irdl.any_of(%[[v0]], %[[v1]], %[[v2]])
+// CHECK-NEXT: %[[v4:[^ ]*]] = irdl.base "!builtin.integer"
+// CHECK-NEXT: irdl.operands(%[[v3]], %[[v4]])
+// CHECK-NEXT: }
// Check that AnyTypeOf is converted correctly.
def Test_OrOp : Test_Op<"or"> {
@@ -53,11 +80,30 @@ def Test_OrOp : Test_Op<"or"> {
// CHECK-NEXT: %[[v0:[^ ]*]] = irdl.base "!test.singleton_a"
// CHECK-NEXT: %[[v1:[^ ]*]] = irdl.base "!test.singleton_b"
// CHECK-NEXT: %[[v2:[^ ]*]] = irdl.base "!test.singleton_c"
-// CHECK-NEXT: %[[v3:[^ ]*]] = irdl.any_of(%[[v0]], %[[v1]], %[[v2]])
+// CHECK-NEXT: %[[v3:[^ ]*]] = irdl.any_of(%[[v0]], %[[v1]], %[[v2]])
// CHECK-NEXT: irdl.operands(%[[v3]])
-// CHECK-NEXT: irdl.results()
// CHECK-NEXT: }
+// Check that various types are converted correctly.
+def Test_TypesOp : Test_Op<"types"> {
+ let arguments = (ins I32:$a,
+ SI64:$b,
+ UI8:$c,
+ Index:$d,
+ F32:$e,
+ NoneType:$f,
+ Complex<F8E4M3FN>);
+}
+// CHECK-LABEL: irdl.operation @types {
+// CHECK-NEXT: %{{.*}} = irdl.is i32
+// CHECK-NEXT: %{{.*}} = irdl.is si64
+// CHECK-NEXT: %{{.*}} = irdl.is ui8
+// CHECK-NEXT: %{{.*}} = irdl.is index
+// CHECK-NEXT: %{{.*}} = irdl.is f32
+// CHECK-NEXT: %{{.*}} = irdl.is none
+// CHECK-NEXT: %{{.*}} = irdl.is complex<f8E4M3FN>
+// CHECK-NEXT: irdl.operands({{.*}})
+// CHECK-NEXT: }
// Check that variadics and optionals are converted correctly.
def Test_VariadicityOp : Test_Op<"variadicity"> {
@@ -70,5 +116,4 @@ def Test_VariadicityOp : Test_Op<"variadicity"> {
// CHECK-NEXT: %[[v1:[^ ]*]] = irdl.base "!test.singleton_b"
// CHECK-NEXT: %[[v2:[^ ]*]] = irdl.base "!test.singleton_c"
// CHECK-NEXT: irdl.operands(variadic %[[v0]], optional %[[v1]], %[[v2]])
-// CHECK-NEXT: irdl.results()
// CHECK-NEXT: }
diff --git a/mlir/tools/tblgen-to-irdl/OpDefinitionsGen.cpp b/mlir/tools/tblgen-to-irdl/OpDefinitionsGen.cpp
index a55f3539f31db0..181d02c6608bdb 100644
--- a/mlir/tools/tblgen-to-irdl/OpDefinitionsGen.cpp
+++ b/mlir/tools/tblgen-to-irdl/OpDefinitionsGen.cpp
@@ -39,6 +39,130 @@ llvm::cl::opt<std::string>
selectedDialect("dialect", llvm::cl::desc("The dialect to gen for"),
llvm::cl::cat(dialectGenCat), llvm::cl::Required);
+Value createPredicate(OpBuilder &builder, tblgen::Pred pred) {
+ MLIRContext *ctx = builder.getContext();
+
+ if (pred.isCombined()) {
+ auto combiner = pred.getDef().getValueAsDef("kind")->getName();
+ if (combiner == "PredCombinerAnd" || combiner == "PredCombinerOr") {
+ std::vector<Value> constraints;
+ for (auto *child : pred.getDef().getValueAsListOfDefs("children")) {
+ constraints.push_back(createPredicate(builder, tblgen::Pred(child)));
+ }
+ if (combiner == "PredCombinerAnd") {
+ auto op =
+ builder.create<irdl::AllOfOp>(UnknownLoc::get(ctx), constraints);
+ return op.getOutput();
+ }
+ auto op =
+ builder.create<irdl::AnyOfOp>(UnknownLoc::get(ctx), constraints);
+ return op.getOutput();
+ }
+ }
+
+ std::string condition = pred.getCondition();
+ // Build a CPredOp to match the C constraint built.
+ irdl::CPredOp op = builder.create<irdl::CPredOp>(
+ UnknownLoc::get(ctx), StringAttr::get(ctx, condition));
+ return op;
+}
+
+Value typeToConstraint(OpBuilder &builder, MLIRContext *ctx, Type type) {
+ auto op =
+ builder.create<irdl::IsOp>(UnknownLoc::get(ctx), TypeAttr::get(type));
+ return op.getOutput();
+}
+
+std::optional<Type> recordToType(MLIRContext *ctx, const Record &predRec) {
+
+ if (predRec.isSubClassOf("I")) {
+ auto width = predRec.getValueAsInt("bitwidth");
+ return IntegerType::get(ctx, width, IntegerType::Signless);
+ }
+
+ if (predRec.isSubClassOf("SI")) {
+ auto width = predRec.getValueAsInt("bitwidth");
+ return IntegerType::get(ctx, width, IntegerType::Signed);
+ }
+
+ if (predRec.isSubClassOf("UI")) {
+ auto width = predRec.getValueAsInt("bitwidth");
+ return IntegerType::get(ctx, width, IntegerType::Unsigned);
+ }
+
+ // Index type
+ if (predRec.getName() == "Index") {
+ return IndexType::get(ctx);
+ }
+
+ // Float types
+ if (predRec.isSubClassOf("F")) {
+ auto width = predRec.getValueAsInt("bitwidth");
+ switch (width) {
+ case 16:
+ return FloatType::getF16(ctx);
+ case 32:
+ return FloatType::getF32(ctx);
+ case 64:
+ return FloatType::getF64(ctx);
+ case 80:
+ return FloatType::getF80(ctx);
+ case 128:
+ return FloatType::getF128(ctx);
+ }
+ }
+
+ if (predRec.getName() == "NoneType") {
+ return NoneType::get(ctx);
+ }
+
+ if (predRec.getName() == "BF16") {
+ return FloatType::getBF16(ctx);
+ }
+
+ if (predRec.getName() == "TF32") {
+ return FloatType::getTF32(ctx);
+ }
+
+ if (predRec.getName() == "F8E4M3FN") {
+ return FloatType::getFloat8E4M3FN(ctx);
+ }
+
+ if (predRec.getName() == "F8E5M2") {
+ return FloatType::getFloat8E5M2(ctx);
+ }
+
+ if (predRec.getName() == "F8E4M3") {
+ return FloatType::getFloat8E4M3(ctx);
+ }
+
+ if (predRec.getName() == "F8E4M3FNUZ") {
+ return FloatType::getFloat8E4M3FNUZ(ctx);
+ }
+
+ if (predRec.getName() == "F8E4M3B11FNUZ") {
+ return FloatType::getFloat8E4M3B11FNUZ(ctx);
+ }
+
+ if (predRec.getName() == "F8E5M2FNUZ") {
+ return FloatType::getFloat8E5M2FNUZ(ctx);
+ }
+
+ if (predRec.getName() == "F8E3M4") {
+ return FloatType::getFloat8E3M4(ctx);
+ }
+
+ if (predRec.isSubClassOf("Complex")) {
+ const Record *elementRec = predRec.getValueAsDef("elementType");
+ auto elementType = recordToType(ctx, *elementRec);
+ if (elementType.has_value()) {
+ return ComplexType::get(elementType.value());
+ }
+ }
+
+ return std::nullopt;
+}
+
Value createConstraint(OpBuilder &builder, tblgen::Constraint constraint) {
MLIRContext *ctx = builder.getContext();
const Record &predRec = constraint.getDef();
@@ -78,11 +202,45 @@ Value createConstraint(OpBuilder &builder, tblgen::Constraint constraint) {
return op.getOutput();
}
- std::string condition = constraint.getPredicate().getCondition();
- // Build a CPredOp to match the C constraint built.
- irdl::CPredOp op = builder.create<irdl::CPredOp>(
- UnknownLoc::get(ctx), StringAttr::get(ctx, condition));
- return op;
+ // Integer types
+ if (predRec.getName() == "AnyInteger") {
+ auto op = builder.create<irdl::BaseOp>(
+ UnknownLoc::get(ctx), StringAttr::get(ctx, "!builtin.integer"));
+ return op.getOutput();
+ }
+
+ if (predRec.isSubClassOf("AnyI")) {
+ auto width = predRec.getValueAsInt("bitwidth");
+ std::vector<Value> types = {
+ typeToConstraint(builder, ctx,
+ IntegerType::get(ctx, width, IntegerType::Signless)),
+ typeToConstraint(builder, ctx,
+ IntegerType::get(ctx, width, IntegerType::Signed)),
+ typeToConstraint(builder, ctx,
+ IntegerType::get(ctx, width, IntegerType::Unsigned))};
+ auto op = builder.create<irdl::AnyOfOp>(UnknownLoc::get(ctx), types);
+ return op.getOutput();
+ }
+
+ auto type = recordToType(ctx, predRec);
+
+ if (type.has_value()) {
+ return typeToConstraint(builder, ctx, type.value());
+ }
+
+ // Confined type
+ if (predRec.isSubClassOf("ConfinedType")) {
+ std::vector<Value> constraints;
+ constraints.push_back(createConstraint(
+ builder, tblgen::Constraint(predRec.getValueAsDef("baseType"))));
+ for (Record *child : predRec.getValueAsListOfDefs("predicateList")) {
+ constraints.push_back(createPredicate(builder, tblgen::Pred(child)));
+ }
+ auto op = builder.create<irdl::AllOfOp>(UnknownLoc::get(ctx), constraints);
+ return op.getOutput();
+ }
+
+ return createPredicate(builder, constraint.getPredicate());
}
/// Returns the name of the operation without the dialect prefix.
@@ -131,10 +289,12 @@ irdl::OperationOp createIRDLOperation(OpBuilder &builder,
auto [results, resultVariadicity] = getValues(tblgenOp.getResults());
// Create the operands and results operations.
- consBuilder.create<irdl::OperandsOp>(UnknownLoc::get(ctx), operands,
- operandVariadicity);
- consBuilder.create<irdl::ResultsOp>(UnknownLoc::get(ctx), results,
- resultVariadicity);
+ if (!operands.empty())
+ consBuilder.create<irdl::OperandsOp>(UnknownLoc::get(ctx), operands,
+ operandVariadicity);
+ if (!results.empty())
+ consBuilder.create<irdl::ResultsOp>(UnknownLoc::get(ctx), results,
+ resultVariadicity);
return op;
}
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
930b9a5
to
1d1851c
Compare
1d1851c
to
54aa7d0
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks a lot!
I just added one comment on a test, but otherwise I'm fine merging it as-is if you think that's irrelevant!
Refactors the tblgen-to-irdl script slightly and adds support for
Also doesn't add the operand and result ops if they are empty.
I could potentially split this into smaller PRs if that'd be helpful (refactor + integer/float/complex, confined type, optional operand/result).
@math-fehr