Skip to content

[mlir] [tblgen-to-irdl] Refactor tblgen-to-irdl script and support more types #105505

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Sep 11, 2024

Conversation

alexarice
Copy link
Contributor

Refactors the tblgen-to-irdl script slightly and adds support for

  • Various integer types
  • Various Float types
  • Confined types
  • Complex types (with fixed element type)

Also doesn't add the operand and result ops if they are empty.

I could potentially split this into smaller PRs if that'd be helpful (refactor + integer/float/complex, confined type, optional operand/result).

@math-fehr

@llvmbot llvmbot added mlir:core MLIR Core Infrastructure mlir mlir:ods labels Aug 21, 2024
@llvmbot
Copy link
Member

llvmbot commented Aug 21, 2024

@llvm/pr-subscribers-mlir-core
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-ods

Author: Alex Rice (alexarice)

Changes

Refactors the tblgen-to-irdl script slightly and adds support for

  • Various integer types
  • Various Float types
  • Confined types
  • Complex types (with fixed element type)

Also doesn't add the operand and result ops if they are empty.

I could potentially split this into smaller PRs if that'd be helpful (refactor + integer/float/complex, confined type, optional operand/result).

@math-fehr


Full diff: https://github.com/llvm/llvm-project/pull/105505.diff

4 Files Affected:

  • (modified) mlir/include/mlir/IR/CommonTypeConstraints.td (+4-1)
  • (modified) mlir/test/tblgen-to-irdl/CMathDialect.td (-1)
  • (modified) mlir/test/tblgen-to-irdl/TestDialect.td (+51-6)
  • (modified) mlir/tools/tblgen-to-irdl/OpDefinitionsGen.cpp (+169-9)
diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td
index 4536d781ef674f..0e076413d0d9f3 100644
--- a/mlir/include/mlir/IR/CommonTypeConstraints.td
+++ b/mlir/include/mlir/IR/CommonTypeConstraints.td
@@ -198,7 +198,10 @@ class AllOfType<list<Type> allowedTypeList, string summary = "",
 class ConfinedType<Type type, list<Pred> predicates, string summary = "",
                    string cppType = type.cppType> : Type<
     And<!listconcat([type.predicate], !foreach(pred, predicates, pred))>,
-    summary, cppType>;
+    summary, cppType> {
+    Type baseType = type;
+    list<Pred> predicateList = predicates;
+}
 
 // Integer types.
 
diff --git a/mlir/test/tblgen-to-irdl/CMathDialect.td b/mlir/test/tblgen-to-irdl/CMathDialect.td
index 5b9e756727cb36..454543e074c489 100644
--- a/mlir/test/tblgen-to-irdl/CMathDialect.td
+++ b/mlir/test/tblgen-to-irdl/CMathDialect.td
@@ -25,7 +25,6 @@ def CMath_ComplexType : CMath_Type<"ComplexType", "complex"> {
 
 // CHECK:      irdl.operation @identity {
 // CHECK-NEXT:   %0 = irdl.base "!cmath.complex"
-// CHECK-NEXT:   irdl.operands()
 // CHECK-NEXT:   irdl.results(%0)
 // CHECK-NEXT: }
 def CMath_IdentityOp : CMath_Op<"identity"> {
diff --git a/mlir/test/tblgen-to-irdl/TestDialect.td b/mlir/test/tblgen-to-irdl/TestDialect.td
index fc40da527db00a..a86dcb5b3b66e2 100644
--- a/mlir/test/tblgen-to-irdl/TestDialect.td
+++ b/mlir/test/tblgen-to-irdl/TestDialect.td
@@ -28,9 +28,8 @@ def Test_AndOp : Test_Op<"and"> {
 // CHECK-LABEL: irdl.operation @and {
 // CHECK-NEXT:    %[[v0:[^ ]*]] = irdl.base "!test.singleton_a"
 // CHECK-NEXT:    %[[v1:[^ ]*]] = irdl.any
-// CHECK-NEXT:    %[[v2:[^ ]*]] = irdl.all_of(%[[v0]], %[[v1]]) 
+// CHECK-NEXT:    %[[v2:[^ ]*]] = irdl.all_of(%[[v0]], %[[v1]])
 // CHECK-NEXT:    irdl.operands(%[[v2]])
-// CHECK-NEXT:    irdl.results()
 // CHECK-NEXT:  }
 
 
@@ -41,9 +40,37 @@ def Test_AnyOp : Test_Op<"any"> {
 // CHECK-LABEL: irdl.operation @any {
 // CHECK-NEXT:    %[[v0:[^ ]*]] = irdl.any
 // CHECK-NEXT:    irdl.operands(%[[v0]])
-// CHECK-NEXT:    irdl.results()
 // CHECK-NEXT:  }
 
+// Check confined types are converted correctly.
+def Test_ConfinedOp : Test_Op<"confined"> {
+  let arguments = (ins ConfinedType<I32, [IntNonNegative.predicate]>:$confined,
+                       ConfinedType<I8, [And<[IntMinValue<1>.predicate, IntMaxValue<2>.predicate]>]>:$bounded);
+}
+// CHECK-LABEL: irdl.operation @confined {
+// CHECK-NEXT:    %[[v0:[^ ]*]] = irdl.is i32
+// CHECK-NEXT:    %[[v1:[^ ]*]] = irdl.c_pred "{{.*}}"
+// CHECK-NEXT:    %[[v2:[^ ]*]] = irdl.all_of(%[[v0]], %[[v1]])
+// CHECK-NEXT:    %[[v3:[^ ]*]] = irdl.is i8
+// CHECK-NEXT:    %[[v4:[^ ]*]] = irdl.c_pred "{{.*}}"
+// CHECK-NEXT:    %[[v5:[^ ]*]] = irdl.c_pred "{{.*}}"
+// CHECK-NEXT:    %[[v6:[^ ]*]] = irdl.all_of(%[[v4]], %[[v5]])
+// CHECK-NEXT:    %[[v7:[^ ]*]] = irdl.all_of(%[[v3]], %[[v6]])
+// CHECK-NEXT:    irdl.operands(%[[v2]], %[[v7]])
+// CHECK-NEXT:  }
+
+def Test_Integers : Test_Op<"integers"> {
+  let arguments = (ins AnyI8:$any_int,
+                       AnyInteger:$any_integer);
+}
+// CHECK-LABEL: irdl.operation @integers {
+// CHECK-NEXT:    %[[v0:[^ ]*]] = irdl.is i8
+// CHECK-NEXT:    %[[v1:[^ ]*]] = irdl.is si8
+// CHECK-NEXT:    %[[v2:[^ ]*]] = irdl.is ui8
+// CHECK-NEXT:    %[[v3:[^ ]*]] = irdl.any_of(%[[v0]], %[[v1]], %[[v2]])
+// CHECK-NEXT:    %[[v4:[^ ]*]] = irdl.base "!builtin.integer"
+// CHECK-NEXT:    irdl.operands(%[[v3]], %[[v4]])
+// CHECK-NEXT:  }
 
 // Check that AnyTypeOf is converted correctly.
 def Test_OrOp : Test_Op<"or"> {
@@ -53,11 +80,30 @@ def Test_OrOp : Test_Op<"or"> {
 // CHECK-NEXT:    %[[v0:[^ ]*]] = irdl.base "!test.singleton_a"
 // CHECK-NEXT:    %[[v1:[^ ]*]] = irdl.base "!test.singleton_b"
 // CHECK-NEXT:    %[[v2:[^ ]*]] = irdl.base "!test.singleton_c"
-// CHECK-NEXT:    %[[v3:[^ ]*]] = irdl.any_of(%[[v0]], %[[v1]], %[[v2]]) 
+// CHECK-NEXT:    %[[v3:[^ ]*]] = irdl.any_of(%[[v0]], %[[v1]], %[[v2]])
 // CHECK-NEXT:    irdl.operands(%[[v3]])
-// CHECK-NEXT:    irdl.results()
 // CHECK-NEXT:  }
 
+// Check that various types are converted correctly.
+def Test_TypesOp : Test_Op<"types"> {
+  let arguments = (ins I32:$a,
+                       SI64:$b,
+                       UI8:$c,
+                       Index:$d,
+                       F32:$e,
+                       NoneType:$f,
+                       Complex<F8E4M3FN>);
+}
+// CHECK-LABEL: irdl.operation @types {
+// CHECK-NEXT:    %{{.*}} = irdl.is i32
+// CHECK-NEXT:    %{{.*}} = irdl.is si64
+// CHECK-NEXT:    %{{.*}} = irdl.is ui8
+// CHECK-NEXT:    %{{.*}} = irdl.is index
+// CHECK-NEXT:    %{{.*}} = irdl.is f32
+// CHECK-NEXT:    %{{.*}} = irdl.is none
+// CHECK-NEXT:    %{{.*}} = irdl.is complex<f8E4M3FN>
+// CHECK-NEXT:    irdl.operands({{.*}})
+// CHECK-NEXT:  }
 
 // Check that variadics and optionals are converted correctly.
 def Test_VariadicityOp : Test_Op<"variadicity"> {
@@ -70,5 +116,4 @@ def Test_VariadicityOp : Test_Op<"variadicity"> {
 // CHECK-NEXT:    %[[v1:[^ ]*]] = irdl.base "!test.singleton_b"
 // CHECK-NEXT:    %[[v2:[^ ]*]] = irdl.base "!test.singleton_c"
 // CHECK-NEXT:    irdl.operands(variadic %[[v0]], optional %[[v1]], %[[v2]])
-// CHECK-NEXT:    irdl.results()
 // CHECK-NEXT:  }
diff --git a/mlir/tools/tblgen-to-irdl/OpDefinitionsGen.cpp b/mlir/tools/tblgen-to-irdl/OpDefinitionsGen.cpp
index a55f3539f31db0..181d02c6608bdb 100644
--- a/mlir/tools/tblgen-to-irdl/OpDefinitionsGen.cpp
+++ b/mlir/tools/tblgen-to-irdl/OpDefinitionsGen.cpp
@@ -39,6 +39,130 @@ llvm::cl::opt<std::string>
     selectedDialect("dialect", llvm::cl::desc("The dialect to gen for"),
                     llvm::cl::cat(dialectGenCat), llvm::cl::Required);
 
+Value createPredicate(OpBuilder &builder, tblgen::Pred pred) {
+  MLIRContext *ctx = builder.getContext();
+
+  if (pred.isCombined()) {
+    auto combiner = pred.getDef().getValueAsDef("kind")->getName();
+    if (combiner == "PredCombinerAnd" || combiner == "PredCombinerOr") {
+      std::vector<Value> constraints;
+      for (auto *child : pred.getDef().getValueAsListOfDefs("children")) {
+        constraints.push_back(createPredicate(builder, tblgen::Pred(child)));
+      }
+      if (combiner == "PredCombinerAnd") {
+        auto op =
+            builder.create<irdl::AllOfOp>(UnknownLoc::get(ctx), constraints);
+        return op.getOutput();
+      }
+      auto op =
+          builder.create<irdl::AnyOfOp>(UnknownLoc::get(ctx), constraints);
+      return op.getOutput();
+    }
+  }
+
+  std::string condition = pred.getCondition();
+  // Build a CPredOp to match the C constraint built.
+  irdl::CPredOp op = builder.create<irdl::CPredOp>(
+      UnknownLoc::get(ctx), StringAttr::get(ctx, condition));
+  return op;
+}
+
+Value typeToConstraint(OpBuilder &builder, MLIRContext *ctx, Type type) {
+  auto op =
+      builder.create<irdl::IsOp>(UnknownLoc::get(ctx), TypeAttr::get(type));
+  return op.getOutput();
+}
+
+std::optional<Type> recordToType(MLIRContext *ctx, const Record &predRec) {
+
+  if (predRec.isSubClassOf("I")) {
+    auto width = predRec.getValueAsInt("bitwidth");
+    return IntegerType::get(ctx, width, IntegerType::Signless);
+  }
+
+  if (predRec.isSubClassOf("SI")) {
+    auto width = predRec.getValueAsInt("bitwidth");
+    return IntegerType::get(ctx, width, IntegerType::Signed);
+  }
+
+  if (predRec.isSubClassOf("UI")) {
+    auto width = predRec.getValueAsInt("bitwidth");
+    return IntegerType::get(ctx, width, IntegerType::Unsigned);
+  }
+
+  // Index type
+  if (predRec.getName() == "Index") {
+    return IndexType::get(ctx);
+  }
+
+  // Float types
+  if (predRec.isSubClassOf("F")) {
+    auto width = predRec.getValueAsInt("bitwidth");
+    switch (width) {
+    case 16:
+      return FloatType::getF16(ctx);
+    case 32:
+      return FloatType::getF32(ctx);
+    case 64:
+      return FloatType::getF64(ctx);
+    case 80:
+      return FloatType::getF80(ctx);
+    case 128:
+      return FloatType::getF128(ctx);
+    }
+  }
+
+  if (predRec.getName() == "NoneType") {
+    return NoneType::get(ctx);
+  }
+
+  if (predRec.getName() == "BF16") {
+    return FloatType::getBF16(ctx);
+  }
+
+  if (predRec.getName() == "TF32") {
+    return FloatType::getTF32(ctx);
+  }
+
+  if (predRec.getName() == "F8E4M3FN") {
+    return FloatType::getFloat8E4M3FN(ctx);
+  }
+
+  if (predRec.getName() == "F8E5M2") {
+    return FloatType::getFloat8E5M2(ctx);
+  }
+
+  if (predRec.getName() == "F8E4M3") {
+    return FloatType::getFloat8E4M3(ctx);
+  }
+
+  if (predRec.getName() == "F8E4M3FNUZ") {
+    return FloatType::getFloat8E4M3FNUZ(ctx);
+  }
+
+  if (predRec.getName() == "F8E4M3B11FNUZ") {
+    return FloatType::getFloat8E4M3B11FNUZ(ctx);
+  }
+
+  if (predRec.getName() == "F8E5M2FNUZ") {
+    return FloatType::getFloat8E5M2FNUZ(ctx);
+  }
+
+  if (predRec.getName() == "F8E3M4") {
+    return FloatType::getFloat8E3M4(ctx);
+  }
+
+  if (predRec.isSubClassOf("Complex")) {
+    const Record *elementRec = predRec.getValueAsDef("elementType");
+    auto elementType = recordToType(ctx, *elementRec);
+    if (elementType.has_value()) {
+      return ComplexType::get(elementType.value());
+    }
+  }
+
+  return std::nullopt;
+}
+
 Value createConstraint(OpBuilder &builder, tblgen::Constraint constraint) {
   MLIRContext *ctx = builder.getContext();
   const Record &predRec = constraint.getDef();
@@ -78,11 +202,45 @@ Value createConstraint(OpBuilder &builder, tblgen::Constraint constraint) {
     return op.getOutput();
   }
 
-  std::string condition = constraint.getPredicate().getCondition();
-  // Build a CPredOp to match the C constraint built.
-  irdl::CPredOp op = builder.create<irdl::CPredOp>(
-      UnknownLoc::get(ctx), StringAttr::get(ctx, condition));
-  return op;
+  // Integer types
+  if (predRec.getName() == "AnyInteger") {
+    auto op = builder.create<irdl::BaseOp>(
+        UnknownLoc::get(ctx), StringAttr::get(ctx, "!builtin.integer"));
+    return op.getOutput();
+  }
+
+  if (predRec.isSubClassOf("AnyI")) {
+    auto width = predRec.getValueAsInt("bitwidth");
+    std::vector<Value> types = {
+        typeToConstraint(builder, ctx,
+                         IntegerType::get(ctx, width, IntegerType::Signless)),
+        typeToConstraint(builder, ctx,
+                         IntegerType::get(ctx, width, IntegerType::Signed)),
+        typeToConstraint(builder, ctx,
+                         IntegerType::get(ctx, width, IntegerType::Unsigned))};
+    auto op = builder.create<irdl::AnyOfOp>(UnknownLoc::get(ctx), types);
+    return op.getOutput();
+  }
+
+  auto type = recordToType(ctx, predRec);
+
+  if (type.has_value()) {
+    return typeToConstraint(builder, ctx, type.value());
+  }
+
+  // Confined type
+  if (predRec.isSubClassOf("ConfinedType")) {
+    std::vector<Value> constraints;
+    constraints.push_back(createConstraint(
+        builder, tblgen::Constraint(predRec.getValueAsDef("baseType"))));
+    for (Record *child : predRec.getValueAsListOfDefs("predicateList")) {
+      constraints.push_back(createPredicate(builder, tblgen::Pred(child)));
+    }
+    auto op = builder.create<irdl::AllOfOp>(UnknownLoc::get(ctx), constraints);
+    return op.getOutput();
+  }
+
+  return createPredicate(builder, constraint.getPredicate());
 }
 
 /// Returns the name of the operation without the dialect prefix.
@@ -131,10 +289,12 @@ irdl::OperationOp createIRDLOperation(OpBuilder &builder,
   auto [results, resultVariadicity] = getValues(tblgenOp.getResults());
 
   // Create the operands and results operations.
-  consBuilder.create<irdl::OperandsOp>(UnknownLoc::get(ctx), operands,
-                                       operandVariadicity);
-  consBuilder.create<irdl::ResultsOp>(UnknownLoc::get(ctx), results,
-                                      resultVariadicity);
+  if (!operands.empty())
+    consBuilder.create<irdl::OperandsOp>(UnknownLoc::get(ctx), operands,
+					 operandVariadicity);
+  if (!results.empty())
+    consBuilder.create<irdl::ResultsOp>(UnknownLoc::get(ctx), results,
+					resultVariadicity);
 
   return op;
 }

Copy link

github-actions bot commented Aug 21, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

@alexarice alexarice force-pushed the alexarice/tblgen-to-irdl-refactor branch 2 times, most recently from 930b9a5 to 1d1851c Compare August 21, 2024 12:19
@alexarice alexarice force-pushed the alexarice/tblgen-to-irdl-refactor branch from 1d1851c to 54aa7d0 Compare August 21, 2024 13:55
@alexarice alexarice changed the title Refactor tblgen-to-irdl script and support more types [mlir] [tblgen-to-irdl] Refactor tblgen-to-irdl script and support more types Aug 21, 2024
Copy link
Contributor

@math-fehr math-fehr left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks a lot!
I just added one comment on a test, but otherwise I'm fine merging it as-is if you think that's irrelevant!

@alexarice alexarice requested a review from math-fehr September 3, 2024 10:45
@math-fehr math-fehr merged commit 135bd31 into llvm:main Sep 11, 2024
8 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:core MLIR Core Infrastructure mlir:ods mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants