Skip to content

[mlir][irdl] Add support for basic structural constraints in tblgen-to-irdl #82862

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Mar 5, 2024

Conversation

math-fehr
Copy link
Contributor

Adds tblgen-to-irdl support for TypeDef, AnyType, AnyTypeOf, and AllOfType ODS constraints.
This is done by introspecting the TableGen constructs directly.

For instance, shape.add now looks like:

    irdl.operation @add {
      %0 = irdl.base "!shape.size" 
      %1 = irdl.c_pred "(::llvm::isa<::mlir::IndexType>($_self))" 
      %2 = irdl.any_of(%0, %1) 
      %3 = irdl.base "!shape.size" 
      %4 = irdl.c_pred "(::llvm::isa<::mlir::IndexType>($_self))" 
      %5 = irdl.any_of(%3, %4) 
      %6 = irdl.base "!shape.size" 
      %7 = irdl.c_pred "(::llvm::isa<::mlir::IndexType>($_self))" 
      %8 = irdl.any_of(%6, %7) 
      irdl.operands(%2, %5)
      irdl.results(%8)
    }

instead of previously

    irdl.operation @add {
      %0 = irdl.c_pred "((::llvm::isa<::mlir::shape::SizeType>($_self))) || ((::llvm::isa<::mlir::IndexType>($_self)))" 
      %1 = irdl.c_pred "((::llvm::isa<::mlir::shape::SizeType>($_self))) || ((::llvm::isa<::mlir::IndexType>($_self)))" 
      %2 = irdl.c_pred "((::llvm::isa<::mlir::shape::SizeType>($_self))) || ((::llvm::isa<::mlir::IndexType>($_self)))" 
      irdl.operands(%0, %1)
      irdl.results(%2)
    }

@math-fehr math-fehr self-assigned this Feb 24, 2024
@llvmbot llvmbot added mlir:core MLIR Core Infrastructure mlir mlir:ods labels Feb 24, 2024
@llvmbot
Copy link
Member

llvmbot commented Feb 24, 2024

@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-ods
@llvm/pr-subscribers-mlir-core

@llvm/pr-subscribers-mlir-irdl

Author: Fehr Mathieu (math-fehr)

Changes

Adds tblgen-to-irdl support for TypeDef, AnyType, AnyTypeOf, and AllOfType ODS constraints.
This is done by introspecting the TableGen constructs directly.

For instance, shape.add now looks like:

    irdl.operation @<!-- -->add {
      %0 = irdl.base "!shape.size" 
      %1 = irdl.c_pred "(::llvm::isa&lt;::mlir::IndexType&gt;($_self))" 
      %2 = irdl.any_of(%0, %1) 
      %3 = irdl.base "!shape.size" 
      %4 = irdl.c_pred "(::llvm::isa&lt;::mlir::IndexType&gt;($_self))" 
      %5 = irdl.any_of(%3, %4) 
      %6 = irdl.base "!shape.size" 
      %7 = irdl.c_pred "(::llvm::isa&lt;::mlir::IndexType&gt;($_self))" 
      %8 = irdl.any_of(%6, %7) 
      irdl.operands(%2, %5)
      irdl.results(%8)
    }

instead of previously

    irdl.operation @<!-- -->add {
      %0 = irdl.c_pred "((::llvm::isa&lt;::mlir::shape::SizeType&gt;($_self))) || ((::llvm::isa&lt;::mlir::IndexType&gt;($_self)))" 
      %1 = irdl.c_pred "((::llvm::isa&lt;::mlir::shape::SizeType&gt;($_self))) || ((::llvm::isa&lt;::mlir::IndexType&gt;($_self)))" 
      %2 = irdl.c_pred "((::llvm::isa&lt;::mlir::shape::SizeType&gt;($_self))) || ((::llvm::isa&lt;::mlir::IndexType&gt;($_self)))" 
      irdl.operands(%0, %1)
      irdl.results(%2)
    }

Full diff: https://github.com/llvm/llvm-project/pull/82862.diff

4 Files Affected:

  • (modified) mlir/include/mlir/IR/CommonTypeConstraints.td (+13-9)
  • (modified) mlir/test/tblgen-to-irdl/CMathDialect.td (+6-6)
  • (added) mlir/test/tblgen-to-irdl/TestDialect.td (+74)
  • (modified) mlir/tools/tblgen-to-irdl/OpDefinitionsGen.cpp (+41-7)
diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td
index 03180a687523bf..af4f13dc09360d 100644
--- a/mlir/include/mlir/IR/CommonTypeConstraints.td
+++ b/mlir/include/mlir/IR/CommonTypeConstraints.td
@@ -168,24 +168,28 @@ def NoneType : Type<CPred<"::llvm::isa<::mlir::NoneType>($_self)">, "none type",
       BuildableType<"$_builder.getType<::mlir::NoneType>()">;
 
 // Any type from the given list
-class AnyTypeOf<list<Type> allowedTypes, string summary = "",
+class AnyTypeOf<list<Type> allowedTypeList, string summary = "",
                 string cppClassName = "::mlir::Type"> : Type<
     // Satisfy any of the allowed types' conditions.
-    Or<!foreach(allowedtype, allowedTypes, allowedtype.predicate)>,
+    Or<!foreach(allowedtype, allowedTypeList, allowedtype.predicate)>,
     !if(!eq(summary, ""),
-        !interleave(!foreach(t, allowedTypes, t.summary), " or "),
+        !interleave(!foreach(t, allowedTypeList, t.summary), " or "),
         summary),
-    cppClassName>;
+    cppClassName> {
+  list<Type> allowedTypes = allowedTypeList;
+}
 
 // A type that satisfies the constraints of all given types.
-class AllOfType<list<Type> allowedTypes, string summary = "",
+class AllOfType<list<Type> allowedTypeList, string summary = "",
                 string cppClassName = "::mlir::Type"> : Type<
-    // Satisfy all of the allowedf types' conditions.
-    And<!foreach(allowedType, allowedTypes, allowedType.predicate)>,
+    // Satisfy all of the allowed types' conditions.
+    And<!foreach(allowedType, allowedTypeList, allowedType.predicate)>,
     !if(!eq(summary, ""),
-        !interleave(!foreach(t, allowedTypes, t.summary), " and "),
+        !interleave(!foreach(t, allowedTypeList, t.summary), " and "),
         summary),
-    cppClassName>;
+    cppClassName> {
+  list<Type> allowedTypes = allowedTypeList;
+}
 
 // A type that satisfies additional predicates.
 class ConfinedType<Type type, list<Pred> predicates, string summary = "",
diff --git a/mlir/test/tblgen-to-irdl/CMathDialect.td b/mlir/test/tblgen-to-irdl/CMathDialect.td
index 57ae8afbba5eeb..5b9e756727cb36 100644
--- a/mlir/test/tblgen-to-irdl/CMathDialect.td
+++ b/mlir/test/tblgen-to-irdl/CMathDialect.td
@@ -24,7 +24,7 @@ def CMath_ComplexType : CMath_Type<"ComplexType", "complex"> {
 }
 
 // CHECK:      irdl.operation @identity {
-// CHECK-NEXT:   %0 = irdl.c_pred "(::llvm::isa<cmath::ComplexTypeType>($_self))" 
+// CHECK-NEXT:   %0 = irdl.base "!cmath.complex"
 // CHECK-NEXT:   irdl.operands()
 // CHECK-NEXT:   irdl.results(%0)
 // CHECK-NEXT: }
@@ -33,9 +33,9 @@ def CMath_IdentityOp : CMath_Op<"identity"> {
 }
 
 // CHECK:      irdl.operation @mul {
-// CHECK-NEXT:   %0 = irdl.c_pred "(::llvm::isa<cmath::ComplexTypeType>($_self))" 
-// CHECK-NEXT:   %1 = irdl.c_pred "(::llvm::isa<cmath::ComplexTypeType>($_self))" 
-// CHECK-NEXT:   %2 = irdl.c_pred "(::llvm::isa<cmath::ComplexTypeType>($_self))" 
+// CHECK-NEXT:   %0 = irdl.base "!cmath.complex"
+// CHECK-NEXT:   %1 = irdl.base "!cmath.complex"
+// CHECK-NEXT:   %2 = irdl.base "!cmath.complex"
 // CHECK-NEXT:   irdl.operands(%0, %1)
 // CHECK-NEXT:   irdl.results(%2)
 // CHECK-NEXT: }
@@ -45,8 +45,8 @@ def CMath_MulOp : CMath_Op<"mul"> {
 }
 
 // CHECK:      irdl.operation @norm {
-// CHECK-NEXT:   %0 = irdl.c_pred "(true)" 
-// CHECK-NEXT:   %1 = irdl.c_pred "(::llvm::isa<cmath::ComplexTypeType>($_self))" 
+// CHECK-NEXT:   %0 = irdl.any
+// CHECK-NEXT:   %1 = irdl.base "!cmath.complex"
 // CHECK-NEXT:   irdl.operands(%0)
 // CHECK-NEXT:   irdl.results(%1)
 // CHECK-NEXT: }
diff --git a/mlir/test/tblgen-to-irdl/TestDialect.td b/mlir/test/tblgen-to-irdl/TestDialect.td
new file mode 100644
index 00000000000000..fc40da527db00a
--- /dev/null
+++ b/mlir/test/tblgen-to-irdl/TestDialect.td
@@ -0,0 +1,74 @@
+// RUN: tblgen-to-irdl %s -I=%S/../../include --gen-dialect-irdl-defs --dialect=test | FileCheck %s
+
+include "mlir/IR/OpBase.td"
+include "mlir/IR/AttrTypeBase.td"
+
+// CHECK-LABEL: irdl.dialect @test {
+def Test_Dialect : Dialect {
+  let name = "test";
+}
+
+class Test_Type<string name, string typeMnemonic, list<Trait> traits = []>
+: TypeDef<Test_Dialect, name, traits> {
+  let mnemonic = typeMnemonic;
+}
+
+class Test_Op<string mnemonic, list<Trait> traits = []>
+    : Op<Test_Dialect, mnemonic, traits>;
+
+def Test_SingletonAType : Test_Type<"SingletonAType", "singleton_a"> {}
+def Test_SingletonBType : Test_Type<"SingletonBType", "singleton_b"> {}
+def Test_SingletonCType : Test_Type<"SingletonCType", "singleton_c"> {}
+
+
+// Check that AllOfType is converted correctly.
+def Test_AndOp : Test_Op<"and"> {
+  let arguments = (ins AllOfType<[Test_SingletonAType, AnyType]>:$in);
+}
+// CHECK-LABEL: irdl.operation @and {
+// CHECK-NEXT:    %[[v0:[^ ]*]] = irdl.base "!test.singleton_a"
+// CHECK-NEXT:    %[[v1:[^ ]*]] = irdl.any
+// CHECK-NEXT:    %[[v2:[^ ]*]] = irdl.all_of(%[[v0]], %[[v1]]) 
+// CHECK-NEXT:    irdl.operands(%[[v2]])
+// CHECK-NEXT:    irdl.results()
+// CHECK-NEXT:  }
+
+
+// Check that AnyType is converted correctly.
+def Test_AnyOp : Test_Op<"any"> {
+  let arguments = (ins AnyType:$in);
+}
+// CHECK-LABEL: irdl.operation @any {
+// CHECK-NEXT:    %[[v0:[^ ]*]] = irdl.any
+// CHECK-NEXT:    irdl.operands(%[[v0]])
+// CHECK-NEXT:    irdl.results()
+// CHECK-NEXT:  }
+
+
+// Check that AnyTypeOf is converted correctly.
+def Test_OrOp : Test_Op<"or"> {
+  let arguments = (ins AnyTypeOf<[Test_SingletonAType, Test_SingletonBType, Test_SingletonCType]>:$in);
+}
+// CHECK-LABEL: irdl.operation @or {
+// CHECK-NEXT:    %[[v0:[^ ]*]] = irdl.base "!test.singleton_a"
+// CHECK-NEXT:    %[[v1:[^ ]*]] = irdl.base "!test.singleton_b"
+// CHECK-NEXT:    %[[v2:[^ ]*]] = irdl.base "!test.singleton_c"
+// CHECK-NEXT:    %[[v3:[^ ]*]] = irdl.any_of(%[[v0]], %[[v1]], %[[v2]]) 
+// CHECK-NEXT:    irdl.operands(%[[v3]])
+// CHECK-NEXT:    irdl.results()
+// CHECK-NEXT:  }
+
+
+// Check that variadics and optionals are converted correctly.
+def Test_VariadicityOp : Test_Op<"variadicity"> {
+  let arguments = (ins Variadic<Test_SingletonAType>:$variadic,
+                       Optional<Test_SingletonBType>:$optional,
+                       Test_SingletonCType:$required);
+}
+// CHECK-LABEL: irdl.operation @variadicity {
+// CHECK-NEXT:    %[[v0:[^ ]*]] = irdl.base "!test.singleton_a"
+// CHECK-NEXT:    %[[v1:[^ ]*]] = irdl.base "!test.singleton_b"
+// CHECK-NEXT:    %[[v2:[^ ]*]] = irdl.base "!test.singleton_c"
+// CHECK-NEXT:    irdl.operands(variadic %[[v0]], optional %[[v1]], %[[v2]])
+// CHECK-NEXT:    irdl.results()
+// CHECK-NEXT:  }
diff --git a/mlir/tools/tblgen-to-irdl/OpDefinitionsGen.cpp b/mlir/tools/tblgen-to-irdl/OpDefinitionsGen.cpp
index ba5bf4d9d4abbc..a55f3539f31db0 100644
--- a/mlir/tools/tblgen-to-irdl/OpDefinitionsGen.cpp
+++ b/mlir/tools/tblgen-to-irdl/OpDefinitionsGen.cpp
@@ -39,15 +39,49 @@ llvm::cl::opt<std::string>
     selectedDialect("dialect", llvm::cl::desc("The dialect to gen for"),
                     llvm::cl::cat(dialectGenCat), llvm::cl::Required);
 
-irdl::CPredOp createConstraint(OpBuilder &builder,
-                               NamedTypeConstraint namedConstraint) {
+Value createConstraint(OpBuilder &builder, tblgen::Constraint constraint) {
   MLIRContext *ctx = builder.getContext();
-  // Build the constraint as a string.
-  std::string constraint =
-      namedConstraint.constraint.getPredicate().getCondition();
+  const Record &predRec = constraint.getDef();
+
+  if (predRec.isSubClassOf("Variadic") || predRec.isSubClassOf("Optional"))
+    return createConstraint(builder, predRec.getValueAsDef("baseType"));
+
+  if (predRec.getName() == "AnyType") {
+    auto op = builder.create<irdl::AnyOp>(UnknownLoc::get(ctx));
+    return op.getOutput();
+  }
+
+  if (predRec.isSubClassOf("TypeDef")) {
+    std::string typeName = ("!" + predRec.getValueAsString("typeName")).str();
+    auto op = builder.create<irdl::BaseOp>(UnknownLoc::get(ctx),
+                                           StringAttr::get(ctx, typeName));
+    return op.getOutput();
+  }
+
+  if (predRec.isSubClassOf("AnyTypeOf")) {
+    std::vector<Value> constraints;
+    for (Record *child : predRec.getValueAsListOfDefs("allowedTypes")) {
+      constraints.push_back(
+          createConstraint(builder, tblgen::Constraint(child)));
+    }
+    auto op = builder.create<irdl::AnyOfOp>(UnknownLoc::get(ctx), constraints);
+    return op.getOutput();
+  }
+
+  if (predRec.isSubClassOf("AllOfType")) {
+    std::vector<Value> constraints;
+    for (Record *child : predRec.getValueAsListOfDefs("allowedTypes")) {
+      constraints.push_back(
+          createConstraint(builder, tblgen::Constraint(child)));
+    }
+    auto op = builder.create<irdl::AllOfOp>(UnknownLoc::get(ctx), constraints);
+    return op.getOutput();
+  }
+
+  std::string condition = constraint.getPredicate().getCondition();
   // Build a CPredOp to match the C constraint built.
   irdl::CPredOp op = builder.create<irdl::CPredOp>(
-      UnknownLoc::get(ctx), StringAttr::get(ctx, constraint));
+      UnknownLoc::get(ctx), StringAttr::get(ctx, condition));
   return op;
 }
 
@@ -74,7 +108,7 @@ irdl::OperationOp createIRDLOperation(OpBuilder &builder,
     SmallVector<Value> operands;
     SmallVector<irdl::VariadicityAttr> variadicity;
     for (const NamedTypeConstraint &namedCons : namedCons) {
-      auto operand = createConstraint(consBuilder, namedCons);
+      auto operand = createConstraint(consBuilder, namedCons.constraint);
       operands.push_back(operand);
 
       irdl::VariadicityAttr var;

@math-fehr math-fehr merged commit a64975f into llvm:main Mar 5, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants