Skip to content

[mlir] [tblgen-to-irdl] Add attributes to tblgen-to-irdl script #109633

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Sep 27, 2024

Conversation

alexarice
Copy link
Contributor

Adds the ability to export attributes from the dialect and attributes of operations in the dialect

@math-fehr

@llvmbot llvmbot added mlir:core MLIR Core Infrastructure mlir mlir:ods labels Sep 23, 2024
@llvmbot
Copy link
Member

llvmbot commented Sep 23, 2024

@llvm/pr-subscribers-mlir-ods

Author: Alex Rice (alexarice)

Changes

Adds the ability to export attributes from the dialect and attributes of operations in the dialect

@math-fehr


Full diff: https://github.com/llvm/llvm-project/pull/109633.diff

3 Files Affected:

  • (modified) mlir/include/mlir/IR/CommonAttrConstraints.td (+3)
  • (modified) mlir/test/tblgen-to-irdl/TestDialect.td (+17)
  • (modified) mlir/tools/tblgen-to-irdl/OpDefinitionsGen.cpp (+132-7)
diff --git a/mlir/include/mlir/IR/CommonAttrConstraints.td b/mlir/include/mlir/IR/CommonAttrConstraints.td
index 853fb318c76e71..de5f6797235e3c 100644
--- a/mlir/include/mlir/IR/CommonAttrConstraints.td
+++ b/mlir/include/mlir/IR/CommonAttrConstraints.td
@@ -178,6 +178,7 @@ class AnyAttrOf<list<Attr> allowedAttrs, string summary = "",
         summary)> {
     let returnType = cppType;
     let convertFromStorage = fromStorage;
+    list<Attr> allowedAttributes = allowedAttrs;
 }
 
 def LocationAttr : Attr<CPred<"::llvm::isa<::mlir::LocationAttr>($_self)">,
@@ -743,6 +744,8 @@ class ConfinedAttr<Attr attr, list<AttrConstraint> constraints> : Attr<
   let isOptional = attr.isOptional;
 
   let baseAttr = attr;
+
+  list<AttrConstraint> attrConstraints = constraints;
 }
 
 // An AttrConstraint that holds if all attr constraints specified in
diff --git a/mlir/test/tblgen-to-irdl/TestDialect.td b/mlir/test/tblgen-to-irdl/TestDialect.td
index 4fea3d8576e9ab..1ba84a5d3683d4 100644
--- a/mlir/test/tblgen-to-irdl/TestDialect.td
+++ b/mlir/test/tblgen-to-irdl/TestDialect.td
@@ -13,6 +13,10 @@ class Test_Type<string name, string typeMnemonic, list<Trait> traits = []>
   let mnemonic = typeMnemonic;
 }
 
+class Test_Attr<string name, string attrMnemonic> : AttrDef<Test_Dialect, name> {
+  let mnemonic = attrMnemonic;
+}
+
 class Test_Op<string mnemonic, list<Trait> traits = []>
     : Op<Test_Dialect, mnemonic, traits>;
 
@@ -22,6 +26,8 @@ def Test_SingletonAType : Test_Type<"SingletonAType", "singleton_a"> {}
 def Test_SingletonBType : Test_Type<"SingletonBType", "singleton_b"> {}
 // CHECK: irdl.type @"!singleton_c"
 def Test_SingletonCType : Test_Type<"SingletonCType", "singleton_c"> {}
+// CHECK: irdl.attribute @"#test"
+def Test_TestAttr : Test_Attr<"Test", "test"> {}
 
 
 // Check that AllOfType is converted correctly.
@@ -45,6 +51,17 @@ def Test_AnyOp : Test_Op<"any"> {
 // CHECK-NEXT:    irdl.operands(%[[v0]])
 // CHECK-NEXT:  }
 
+// Check attributes are converted correctly.
+def Test_AttributesOp : Test_Op<"attributes"> {
+  let arguments = (ins I16Attr:$int_attr,
+                       Test_TestAttr:$test_attr);
+}
+// CHECK-LABEL: irdl.operation @attributes {
+// CHECK-NEXT:    %[[v0:[^ ]*]] = irdl.base "!builtin.integer"
+// CHECK-NEXT:    %[[v1:[^ ]*]] = irdl.base @test::@"#test"
+// CHECK-NEXT:    irdl.attributes {"int_attr" = %[[v0]], "test_attr" = %[[v1]]}
+// CHECK-NEXT:  }
+
 // Check confined types are converted correctly.
 def Test_ConfinedOp : Test_Op<"confined"> {
   let arguments = (ins ConfinedType<AnyType, [CPred<"::llvm::isa<::mlir::TensorType>($_self)">]>:$tensor,
diff --git a/mlir/tools/tblgen-to-irdl/OpDefinitionsGen.cpp b/mlir/tools/tblgen-to-irdl/OpDefinitionsGen.cpp
index 45957bafc378e3..d0a3552fb123da 100644
--- a/mlir/tools/tblgen-to-irdl/OpDefinitionsGen.cpp
+++ b/mlir/tools/tblgen-to-irdl/OpDefinitionsGen.cpp
@@ -74,8 +74,14 @@ Value typeToConstraint(OpBuilder &builder, Type type) {
   return op.getOutput();
 }
 
-std::optional<Type> recordToType(MLIRContext *ctx, const Record &predRec) {
+Value baseToConstraint(OpBuilder &builder, StringRef baseClass) {
+  MLIRContext *ctx = builder.getContext();
+  auto op = builder.create<irdl::BaseOp>(UnknownLoc::get(ctx),
+                                         StringAttr::get(ctx, baseClass));
+  return op.getOutput();
+}
 
+std::optional<Type> recordToType(MLIRContext *ctx, const Record &predRec) {
   if (predRec.isSubClassOf("I")) {
     auto width = predRec.getValueAsInt("bitwidth");
     return IntegerType::get(ctx, width, IntegerType::Signless);
@@ -164,12 +170,12 @@ std::optional<Type> recordToType(MLIRContext *ctx, const Record &predRec) {
   return std::nullopt;
 }
 
-Value createConstraint(OpBuilder &builder, tblgen::Constraint constraint) {
+Value createTypeConstraint(OpBuilder &builder, tblgen::Constraint constraint) {
   MLIRContext *ctx = builder.getContext();
   const Record &predRec = constraint.getDef();
 
   if (predRec.isSubClassOf("Variadic") || predRec.isSubClassOf("Optional"))
-    return createConstraint(builder, predRec.getValueAsDef("baseType"));
+    return createTypeConstraint(builder, predRec.getValueAsDef("baseType"));
 
   if (predRec.getName() == "AnyType") {
     auto op = builder.create<irdl::AnyOp>(UnknownLoc::get(ctx));
@@ -196,7 +202,7 @@ Value createConstraint(OpBuilder &builder, tblgen::Constraint constraint) {
     std::vector<Value> constraints;
     for (const Record *child : predRec.getValueAsListOfDefs("allowedTypes")) {
       constraints.push_back(
-          createConstraint(builder, tblgen::Constraint(child)));
+          createTypeConstraint(builder, tblgen::Constraint(child)));
     }
     auto op = builder.create<irdl::AnyOfOp>(UnknownLoc::get(ctx), constraints);
     return op.getOutput();
@@ -206,7 +212,7 @@ Value createConstraint(OpBuilder &builder, tblgen::Constraint constraint) {
     std::vector<Value> constraints;
     for (const Record *child : predRec.getValueAsListOfDefs("allowedTypes")) {
       constraints.push_back(
-          createConstraint(builder, tblgen::Constraint(child)));
+          createTypeConstraint(builder, tblgen::Constraint(child)));
     }
     auto op = builder.create<irdl::AllOfOp>(UnknownLoc::get(ctx), constraints);
     return op.getOutput();
@@ -241,7 +247,7 @@ Value createConstraint(OpBuilder &builder, tblgen::Constraint constraint) {
   // Confined type
   if (predRec.isSubClassOf("ConfinedType")) {
     std::vector<Value> constraints;
-    constraints.push_back(createConstraint(
+    constraints.push_back(createTypeConstraint(
         builder, tblgen::Constraint(predRec.getValueAsDef("baseType"))));
     for (Record *child : predRec.getValueAsListOfDefs("predicateList")) {
       constraints.push_back(createPredicate(builder, tblgen::Pred(child)));
@@ -253,6 +259,85 @@ Value createConstraint(OpBuilder &builder, tblgen::Constraint constraint) {
   return createPredicate(builder, constraint.getPredicate());
 }
 
+Value createAttrConstraint(OpBuilder &builder, tblgen::Constraint constraint) {
+  MLIRContext *ctx = builder.getContext();
+  const Record &predRec = constraint.getDef();
+
+  if (predRec.isSubClassOf("DefaultValuedAttr") ||
+      predRec.isSubClassOf("DefaultValuedOptionalAttr") ||
+      predRec.isSubClassOf("OptionalAttr")) {
+    return createAttrConstraint(builder, predRec.getValueAsDef("baseAttr"));
+  }
+
+  if (predRec.isSubClassOf("ConfinedAttr")) {
+    std::vector<Value> constraints;
+    constraints.push_back(createAttrConstraint(
+        builder, tblgen::Constraint(predRec.getValueAsDef("baseAttr"))));
+    for (Record *child : predRec.getValueAsListOfDefs("attrConstraints")) {
+      constraints.push_back(createPredicate(
+          builder, tblgen::Pred(child->getValueAsDef("predicate"))));
+    }
+    auto op = builder.create<irdl::AllOfOp>(UnknownLoc::get(ctx), constraints);
+    return op.getOutput();
+  }
+
+  if (predRec.isSubClassOf("AnyAttrOf")) {
+    std::vector<Value> constraints;
+    for (Record *child : predRec.getValueAsListOfDefs("allowedAttributes")) {
+      constraints.push_back(
+          createAttrConstraint(builder, tblgen::Constraint(child)));
+    }
+    auto op = builder.create<irdl::AnyOfOp>(UnknownLoc::get(ctx), constraints);
+    return op.getOutput();
+  }
+
+  if (predRec.getName() == "AnyAttr") {
+    auto op = builder.create<irdl::AnyOp>(UnknownLoc::get(ctx));
+    return op.getOutput();
+  }
+
+  if (predRec.isSubClassOf("AnyIntegerAttrBase") ||
+      predRec.isSubClassOf("SignlessIntegerAttrBase") ||
+      predRec.isSubClassOf("SignedIntegerAttrBase") ||
+      predRec.isSubClassOf("UnsignedIntegerAttrBase") ||
+      predRec.isSubClassOf("BoolAttr")) {
+    return baseToConstraint(builder, "!builtin.integer");
+  }
+
+  if (predRec.isSubClassOf("FloatAttrBase")) {
+    return baseToConstraint(builder, "!builtin.float");
+  }
+
+  if (predRec.isSubClassOf("StringBasedAttr")) {
+    return baseToConstraint(builder, "!builtin.string");
+  }
+
+  if (predRec.getName() == "UnitAttr") {
+    auto op =
+        builder.create<irdl::IsOp>(UnknownLoc::get(ctx), UnitAttr::get(ctx));
+    return op.getOutput();
+  }
+
+  if (predRec.isSubClassOf("AttrDef")) {
+    auto dialect = predRec.getValueAsDef("dialect")->getValueAsString("name");
+    if (dialect == selectedDialect) {
+      std::string combined = ("#" + predRec.getValueAsString("mnemonic")).str();
+      SmallVector<FlatSymbolRefAttr> nested = {SymbolRefAttr::get(ctx, combined)
+
+      };
+      auto typeSymbol = SymbolRefAttr::get(ctx, dialect, nested);
+      auto op = builder.create<irdl::BaseOp>(UnknownLoc::get(ctx), typeSymbol);
+      return op.getOutput();
+    }
+    std::string typeName = ("#" + predRec.getValueAsString("attrName")).str();
+    auto op = builder.create<irdl::BaseOp>(UnknownLoc::get(ctx),
+                                           StringAttr::get(ctx, typeName));
+    return op.getOutput();
+  }
+
+  return createPredicate(builder, constraint.getPredicate());
+}
+
 /// Returns the name of the operation without the dialect prefix.
 static StringRef getOperatorName(tblgen::Operator &tblgenOp) {
   StringRef opName = tblgenOp.getDef().getValueAsString("opName");
@@ -265,6 +350,12 @@ static StringRef getTypeName(tblgen::TypeDef &tblgenType) {
   return opName;
 }
 
+/// Returns the name of the attr without the dialect prefix.
+static StringRef getAttrName(tblgen::AttrDef &tblgenType) {
+  StringRef opName = tblgenType.getDef()->getValueAsString("mnemonic");
+  return opName;
+}
+
 /// Extract an operation to IRDL.
 irdl::OperationOp createIRDLOperation(OpBuilder &builder,
                                       tblgen::Operator &tblgenOp) {
@@ -282,7 +373,7 @@ irdl::OperationOp createIRDLOperation(OpBuilder &builder,
     SmallVector<Value> operands;
     SmallVector<irdl::VariadicityAttr> variadicity;
     for (const NamedTypeConstraint &namedCons : namedCons) {
-      auto operand = createConstraint(consBuilder, namedCons.constraint);
+      auto operand = createTypeConstraint(consBuilder, namedCons.constraint);
       operands.push_back(operand);
 
       irdl::VariadicityAttr var;
@@ -304,6 +395,15 @@ irdl::OperationOp createIRDLOperation(OpBuilder &builder,
   auto [operands, operandVariadicity] = getValues(tblgenOp.getOperands());
   auto [results, resultVariadicity] = getValues(tblgenOp.getResults());
 
+  SmallVector<Value> attributes;
+  SmallVector<Attribute> attrNames;
+  for (auto namedAttr : tblgenOp.getAttributes()) {
+    if (namedAttr.attr.isOptional())
+      continue;
+    attributes.push_back(createAttrConstraint(consBuilder, namedAttr.attr));
+    attrNames.push_back(StringAttr::get(ctx, namedAttr.name));
+  }
+
   // Create the operands and results operations.
   if (!operands.empty())
     consBuilder.create<irdl::OperandsOp>(UnknownLoc::get(ctx), operands,
@@ -311,6 +411,9 @@ irdl::OperationOp createIRDLOperation(OpBuilder &builder,
   if (!results.empty())
     consBuilder.create<irdl::ResultsOp>(UnknownLoc::get(ctx), results,
                                         resultVariadicity);
+  if (!attributes.empty())
+    consBuilder.create<irdl::AttributesOp>(UnknownLoc::get(ctx), attributes,
+                                           ArrayAttr::get(ctx, attrNames));
 
   return op;
 }
@@ -328,6 +431,20 @@ irdl::TypeOp createIRDLType(OpBuilder &builder, tblgen::TypeDef &tblgenType) {
   return op;
 }
 
+irdl::AttributeOp createIRDLAttr(OpBuilder &builder,
+                                 tblgen::AttrDef &tblgenAttr) {
+  MLIRContext *ctx = builder.getContext();
+  StringRef attrName = getAttrName(tblgenAttr);
+  std::string combined = ("#" + attrName).str();
+
+  irdl::AttributeOp op = builder.create<irdl::AttributeOp>(
+      UnknownLoc::get(ctx), StringAttr::get(ctx, combined));
+
+  op.getBody().emplaceBlock();
+
+  return op;
+}
+
 static irdl::DialectOp createIRDLDialect(OpBuilder &builder) {
   MLIRContext *ctx = builder.getContext();
   return builder.create<irdl::DialectOp>(UnknownLoc::get(ctx),
@@ -358,6 +475,14 @@ static bool emitDialectIRDLDefs(const RecordKeeper &recordKeeper,
     createIRDLType(builder, tblgenType);
   }
 
+  for (const Record *attr :
+       recordKeeper.getAllDerivedDefinitionsIfDefined("AttrDef")) {
+    tblgen::AttrDef tblgenAttr(attr);
+    if (tblgenAttr.getDialect().getName() != selectedDialect)
+      continue;
+    createIRDLAttr(builder, tblgenAttr);
+  }
+
   for (const Record *def :
        recordKeeper.getAllDerivedDefinitionsIfDefined("Op")) {
     tblgen::Operator tblgenOp(def);

@llvmbot
Copy link
Member

llvmbot commented Sep 23, 2024

@llvm/pr-subscribers-mlir-core

Author: Alex Rice (alexarice)

Changes

Adds the ability to export attributes from the dialect and attributes of operations in the dialect

@math-fehr


Full diff: https://github.com/llvm/llvm-project/pull/109633.diff

3 Files Affected:

  • (modified) mlir/include/mlir/IR/CommonAttrConstraints.td (+3)
  • (modified) mlir/test/tblgen-to-irdl/TestDialect.td (+17)
  • (modified) mlir/tools/tblgen-to-irdl/OpDefinitionsGen.cpp (+132-7)
diff --git a/mlir/include/mlir/IR/CommonAttrConstraints.td b/mlir/include/mlir/IR/CommonAttrConstraints.td
index 853fb318c76e71..de5f6797235e3c 100644
--- a/mlir/include/mlir/IR/CommonAttrConstraints.td
+++ b/mlir/include/mlir/IR/CommonAttrConstraints.td
@@ -178,6 +178,7 @@ class AnyAttrOf<list<Attr> allowedAttrs, string summary = "",
         summary)> {
     let returnType = cppType;
     let convertFromStorage = fromStorage;
+    list<Attr> allowedAttributes = allowedAttrs;
 }
 
 def LocationAttr : Attr<CPred<"::llvm::isa<::mlir::LocationAttr>($_self)">,
@@ -743,6 +744,8 @@ class ConfinedAttr<Attr attr, list<AttrConstraint> constraints> : Attr<
   let isOptional = attr.isOptional;
 
   let baseAttr = attr;
+
+  list<AttrConstraint> attrConstraints = constraints;
 }
 
 // An AttrConstraint that holds if all attr constraints specified in
diff --git a/mlir/test/tblgen-to-irdl/TestDialect.td b/mlir/test/tblgen-to-irdl/TestDialect.td
index 4fea3d8576e9ab..1ba84a5d3683d4 100644
--- a/mlir/test/tblgen-to-irdl/TestDialect.td
+++ b/mlir/test/tblgen-to-irdl/TestDialect.td
@@ -13,6 +13,10 @@ class Test_Type<string name, string typeMnemonic, list<Trait> traits = []>
   let mnemonic = typeMnemonic;
 }
 
+class Test_Attr<string name, string attrMnemonic> : AttrDef<Test_Dialect, name> {
+  let mnemonic = attrMnemonic;
+}
+
 class Test_Op<string mnemonic, list<Trait> traits = []>
     : Op<Test_Dialect, mnemonic, traits>;
 
@@ -22,6 +26,8 @@ def Test_SingletonAType : Test_Type<"SingletonAType", "singleton_a"> {}
 def Test_SingletonBType : Test_Type<"SingletonBType", "singleton_b"> {}
 // CHECK: irdl.type @"!singleton_c"
 def Test_SingletonCType : Test_Type<"SingletonCType", "singleton_c"> {}
+// CHECK: irdl.attribute @"#test"
+def Test_TestAttr : Test_Attr<"Test", "test"> {}
 
 
 // Check that AllOfType is converted correctly.
@@ -45,6 +51,17 @@ def Test_AnyOp : Test_Op<"any"> {
 // CHECK-NEXT:    irdl.operands(%[[v0]])
 // CHECK-NEXT:  }
 
+// Check attributes are converted correctly.
+def Test_AttributesOp : Test_Op<"attributes"> {
+  let arguments = (ins I16Attr:$int_attr,
+                       Test_TestAttr:$test_attr);
+}
+// CHECK-LABEL: irdl.operation @attributes {
+// CHECK-NEXT:    %[[v0:[^ ]*]] = irdl.base "!builtin.integer"
+// CHECK-NEXT:    %[[v1:[^ ]*]] = irdl.base @test::@"#test"
+// CHECK-NEXT:    irdl.attributes {"int_attr" = %[[v0]], "test_attr" = %[[v1]]}
+// CHECK-NEXT:  }
+
 // Check confined types are converted correctly.
 def Test_ConfinedOp : Test_Op<"confined"> {
   let arguments = (ins ConfinedType<AnyType, [CPred<"::llvm::isa<::mlir::TensorType>($_self)">]>:$tensor,
diff --git a/mlir/tools/tblgen-to-irdl/OpDefinitionsGen.cpp b/mlir/tools/tblgen-to-irdl/OpDefinitionsGen.cpp
index 45957bafc378e3..d0a3552fb123da 100644
--- a/mlir/tools/tblgen-to-irdl/OpDefinitionsGen.cpp
+++ b/mlir/tools/tblgen-to-irdl/OpDefinitionsGen.cpp
@@ -74,8 +74,14 @@ Value typeToConstraint(OpBuilder &builder, Type type) {
   return op.getOutput();
 }
 
-std::optional<Type> recordToType(MLIRContext *ctx, const Record &predRec) {
+Value baseToConstraint(OpBuilder &builder, StringRef baseClass) {
+  MLIRContext *ctx = builder.getContext();
+  auto op = builder.create<irdl::BaseOp>(UnknownLoc::get(ctx),
+                                         StringAttr::get(ctx, baseClass));
+  return op.getOutput();
+}
 
+std::optional<Type> recordToType(MLIRContext *ctx, const Record &predRec) {
   if (predRec.isSubClassOf("I")) {
     auto width = predRec.getValueAsInt("bitwidth");
     return IntegerType::get(ctx, width, IntegerType::Signless);
@@ -164,12 +170,12 @@ std::optional<Type> recordToType(MLIRContext *ctx, const Record &predRec) {
   return std::nullopt;
 }
 
-Value createConstraint(OpBuilder &builder, tblgen::Constraint constraint) {
+Value createTypeConstraint(OpBuilder &builder, tblgen::Constraint constraint) {
   MLIRContext *ctx = builder.getContext();
   const Record &predRec = constraint.getDef();
 
   if (predRec.isSubClassOf("Variadic") || predRec.isSubClassOf("Optional"))
-    return createConstraint(builder, predRec.getValueAsDef("baseType"));
+    return createTypeConstraint(builder, predRec.getValueAsDef("baseType"));
 
   if (predRec.getName() == "AnyType") {
     auto op = builder.create<irdl::AnyOp>(UnknownLoc::get(ctx));
@@ -196,7 +202,7 @@ Value createConstraint(OpBuilder &builder, tblgen::Constraint constraint) {
     std::vector<Value> constraints;
     for (const Record *child : predRec.getValueAsListOfDefs("allowedTypes")) {
       constraints.push_back(
-          createConstraint(builder, tblgen::Constraint(child)));
+          createTypeConstraint(builder, tblgen::Constraint(child)));
     }
     auto op = builder.create<irdl::AnyOfOp>(UnknownLoc::get(ctx), constraints);
     return op.getOutput();
@@ -206,7 +212,7 @@ Value createConstraint(OpBuilder &builder, tblgen::Constraint constraint) {
     std::vector<Value> constraints;
     for (const Record *child : predRec.getValueAsListOfDefs("allowedTypes")) {
       constraints.push_back(
-          createConstraint(builder, tblgen::Constraint(child)));
+          createTypeConstraint(builder, tblgen::Constraint(child)));
     }
     auto op = builder.create<irdl::AllOfOp>(UnknownLoc::get(ctx), constraints);
     return op.getOutput();
@@ -241,7 +247,7 @@ Value createConstraint(OpBuilder &builder, tblgen::Constraint constraint) {
   // Confined type
   if (predRec.isSubClassOf("ConfinedType")) {
     std::vector<Value> constraints;
-    constraints.push_back(createConstraint(
+    constraints.push_back(createTypeConstraint(
         builder, tblgen::Constraint(predRec.getValueAsDef("baseType"))));
     for (Record *child : predRec.getValueAsListOfDefs("predicateList")) {
       constraints.push_back(createPredicate(builder, tblgen::Pred(child)));
@@ -253,6 +259,85 @@ Value createConstraint(OpBuilder &builder, tblgen::Constraint constraint) {
   return createPredicate(builder, constraint.getPredicate());
 }
 
+Value createAttrConstraint(OpBuilder &builder, tblgen::Constraint constraint) {
+  MLIRContext *ctx = builder.getContext();
+  const Record &predRec = constraint.getDef();
+
+  if (predRec.isSubClassOf("DefaultValuedAttr") ||
+      predRec.isSubClassOf("DefaultValuedOptionalAttr") ||
+      predRec.isSubClassOf("OptionalAttr")) {
+    return createAttrConstraint(builder, predRec.getValueAsDef("baseAttr"));
+  }
+
+  if (predRec.isSubClassOf("ConfinedAttr")) {
+    std::vector<Value> constraints;
+    constraints.push_back(createAttrConstraint(
+        builder, tblgen::Constraint(predRec.getValueAsDef("baseAttr"))));
+    for (Record *child : predRec.getValueAsListOfDefs("attrConstraints")) {
+      constraints.push_back(createPredicate(
+          builder, tblgen::Pred(child->getValueAsDef("predicate"))));
+    }
+    auto op = builder.create<irdl::AllOfOp>(UnknownLoc::get(ctx), constraints);
+    return op.getOutput();
+  }
+
+  if (predRec.isSubClassOf("AnyAttrOf")) {
+    std::vector<Value> constraints;
+    for (Record *child : predRec.getValueAsListOfDefs("allowedAttributes")) {
+      constraints.push_back(
+          createAttrConstraint(builder, tblgen::Constraint(child)));
+    }
+    auto op = builder.create<irdl::AnyOfOp>(UnknownLoc::get(ctx), constraints);
+    return op.getOutput();
+  }
+
+  if (predRec.getName() == "AnyAttr") {
+    auto op = builder.create<irdl::AnyOp>(UnknownLoc::get(ctx));
+    return op.getOutput();
+  }
+
+  if (predRec.isSubClassOf("AnyIntegerAttrBase") ||
+      predRec.isSubClassOf("SignlessIntegerAttrBase") ||
+      predRec.isSubClassOf("SignedIntegerAttrBase") ||
+      predRec.isSubClassOf("UnsignedIntegerAttrBase") ||
+      predRec.isSubClassOf("BoolAttr")) {
+    return baseToConstraint(builder, "!builtin.integer");
+  }
+
+  if (predRec.isSubClassOf("FloatAttrBase")) {
+    return baseToConstraint(builder, "!builtin.float");
+  }
+
+  if (predRec.isSubClassOf("StringBasedAttr")) {
+    return baseToConstraint(builder, "!builtin.string");
+  }
+
+  if (predRec.getName() == "UnitAttr") {
+    auto op =
+        builder.create<irdl::IsOp>(UnknownLoc::get(ctx), UnitAttr::get(ctx));
+    return op.getOutput();
+  }
+
+  if (predRec.isSubClassOf("AttrDef")) {
+    auto dialect = predRec.getValueAsDef("dialect")->getValueAsString("name");
+    if (dialect == selectedDialect) {
+      std::string combined = ("#" + predRec.getValueAsString("mnemonic")).str();
+      SmallVector<FlatSymbolRefAttr> nested = {SymbolRefAttr::get(ctx, combined)
+
+      };
+      auto typeSymbol = SymbolRefAttr::get(ctx, dialect, nested);
+      auto op = builder.create<irdl::BaseOp>(UnknownLoc::get(ctx), typeSymbol);
+      return op.getOutput();
+    }
+    std::string typeName = ("#" + predRec.getValueAsString("attrName")).str();
+    auto op = builder.create<irdl::BaseOp>(UnknownLoc::get(ctx),
+                                           StringAttr::get(ctx, typeName));
+    return op.getOutput();
+  }
+
+  return createPredicate(builder, constraint.getPredicate());
+}
+
 /// Returns the name of the operation without the dialect prefix.
 static StringRef getOperatorName(tblgen::Operator &tblgenOp) {
   StringRef opName = tblgenOp.getDef().getValueAsString("opName");
@@ -265,6 +350,12 @@ static StringRef getTypeName(tblgen::TypeDef &tblgenType) {
   return opName;
 }
 
+/// Returns the name of the attr without the dialect prefix.
+static StringRef getAttrName(tblgen::AttrDef &tblgenType) {
+  StringRef opName = tblgenType.getDef()->getValueAsString("mnemonic");
+  return opName;
+}
+
 /// Extract an operation to IRDL.
 irdl::OperationOp createIRDLOperation(OpBuilder &builder,
                                       tblgen::Operator &tblgenOp) {
@@ -282,7 +373,7 @@ irdl::OperationOp createIRDLOperation(OpBuilder &builder,
     SmallVector<Value> operands;
     SmallVector<irdl::VariadicityAttr> variadicity;
     for (const NamedTypeConstraint &namedCons : namedCons) {
-      auto operand = createConstraint(consBuilder, namedCons.constraint);
+      auto operand = createTypeConstraint(consBuilder, namedCons.constraint);
       operands.push_back(operand);
 
       irdl::VariadicityAttr var;
@@ -304,6 +395,15 @@ irdl::OperationOp createIRDLOperation(OpBuilder &builder,
   auto [operands, operandVariadicity] = getValues(tblgenOp.getOperands());
   auto [results, resultVariadicity] = getValues(tblgenOp.getResults());
 
+  SmallVector<Value> attributes;
+  SmallVector<Attribute> attrNames;
+  for (auto namedAttr : tblgenOp.getAttributes()) {
+    if (namedAttr.attr.isOptional())
+      continue;
+    attributes.push_back(createAttrConstraint(consBuilder, namedAttr.attr));
+    attrNames.push_back(StringAttr::get(ctx, namedAttr.name));
+  }
+
   // Create the operands and results operations.
   if (!operands.empty())
     consBuilder.create<irdl::OperandsOp>(UnknownLoc::get(ctx), operands,
@@ -311,6 +411,9 @@ irdl::OperationOp createIRDLOperation(OpBuilder &builder,
   if (!results.empty())
     consBuilder.create<irdl::ResultsOp>(UnknownLoc::get(ctx), results,
                                         resultVariadicity);
+  if (!attributes.empty())
+    consBuilder.create<irdl::AttributesOp>(UnknownLoc::get(ctx), attributes,
+                                           ArrayAttr::get(ctx, attrNames));
 
   return op;
 }
@@ -328,6 +431,20 @@ irdl::TypeOp createIRDLType(OpBuilder &builder, tblgen::TypeDef &tblgenType) {
   return op;
 }
 
+irdl::AttributeOp createIRDLAttr(OpBuilder &builder,
+                                 tblgen::AttrDef &tblgenAttr) {
+  MLIRContext *ctx = builder.getContext();
+  StringRef attrName = getAttrName(tblgenAttr);
+  std::string combined = ("#" + attrName).str();
+
+  irdl::AttributeOp op = builder.create<irdl::AttributeOp>(
+      UnknownLoc::get(ctx), StringAttr::get(ctx, combined));
+
+  op.getBody().emplaceBlock();
+
+  return op;
+}
+
 static irdl::DialectOp createIRDLDialect(OpBuilder &builder) {
   MLIRContext *ctx = builder.getContext();
   return builder.create<irdl::DialectOp>(UnknownLoc::get(ctx),
@@ -358,6 +475,14 @@ static bool emitDialectIRDLDefs(const RecordKeeper &recordKeeper,
     createIRDLType(builder, tblgenType);
   }
 
+  for (const Record *attr :
+       recordKeeper.getAllDerivedDefinitionsIfDefined("AttrDef")) {
+    tblgen::AttrDef tblgenAttr(attr);
+    if (tblgenAttr.getDialect().getName() != selectedDialect)
+      continue;
+    createIRDLAttr(builder, tblgenAttr);
+  }
+
   for (const Record *def :
        recordKeeper.getAllDerivedDefinitionsIfDefined("Op")) {
     tblgen::Operator tblgenOp(def);

@llvmbot
Copy link
Member

llvmbot commented Sep 23, 2024

@llvm/pr-subscribers-mlir

Author: Alex Rice (alexarice)

Changes

Adds the ability to export attributes from the dialect and attributes of operations in the dialect

@math-fehr


Full diff: https://github.com/llvm/llvm-project/pull/109633.diff

3 Files Affected:

  • (modified) mlir/include/mlir/IR/CommonAttrConstraints.td (+3)
  • (modified) mlir/test/tblgen-to-irdl/TestDialect.td (+17)
  • (modified) mlir/tools/tblgen-to-irdl/OpDefinitionsGen.cpp (+132-7)
diff --git a/mlir/include/mlir/IR/CommonAttrConstraints.td b/mlir/include/mlir/IR/CommonAttrConstraints.td
index 853fb318c76e71..de5f6797235e3c 100644
--- a/mlir/include/mlir/IR/CommonAttrConstraints.td
+++ b/mlir/include/mlir/IR/CommonAttrConstraints.td
@@ -178,6 +178,7 @@ class AnyAttrOf<list<Attr> allowedAttrs, string summary = "",
         summary)> {
     let returnType = cppType;
     let convertFromStorage = fromStorage;
+    list<Attr> allowedAttributes = allowedAttrs;
 }
 
 def LocationAttr : Attr<CPred<"::llvm::isa<::mlir::LocationAttr>($_self)">,
@@ -743,6 +744,8 @@ class ConfinedAttr<Attr attr, list<AttrConstraint> constraints> : Attr<
   let isOptional = attr.isOptional;
 
   let baseAttr = attr;
+
+  list<AttrConstraint> attrConstraints = constraints;
 }
 
 // An AttrConstraint that holds if all attr constraints specified in
diff --git a/mlir/test/tblgen-to-irdl/TestDialect.td b/mlir/test/tblgen-to-irdl/TestDialect.td
index 4fea3d8576e9ab..1ba84a5d3683d4 100644
--- a/mlir/test/tblgen-to-irdl/TestDialect.td
+++ b/mlir/test/tblgen-to-irdl/TestDialect.td
@@ -13,6 +13,10 @@ class Test_Type<string name, string typeMnemonic, list<Trait> traits = []>
   let mnemonic = typeMnemonic;
 }
 
+class Test_Attr<string name, string attrMnemonic> : AttrDef<Test_Dialect, name> {
+  let mnemonic = attrMnemonic;
+}
+
 class Test_Op<string mnemonic, list<Trait> traits = []>
     : Op<Test_Dialect, mnemonic, traits>;
 
@@ -22,6 +26,8 @@ def Test_SingletonAType : Test_Type<"SingletonAType", "singleton_a"> {}
 def Test_SingletonBType : Test_Type<"SingletonBType", "singleton_b"> {}
 // CHECK: irdl.type @"!singleton_c"
 def Test_SingletonCType : Test_Type<"SingletonCType", "singleton_c"> {}
+// CHECK: irdl.attribute @"#test"
+def Test_TestAttr : Test_Attr<"Test", "test"> {}
 
 
 // Check that AllOfType is converted correctly.
@@ -45,6 +51,17 @@ def Test_AnyOp : Test_Op<"any"> {
 // CHECK-NEXT:    irdl.operands(%[[v0]])
 // CHECK-NEXT:  }
 
+// Check attributes are converted correctly.
+def Test_AttributesOp : Test_Op<"attributes"> {
+  let arguments = (ins I16Attr:$int_attr,
+                       Test_TestAttr:$test_attr);
+}
+// CHECK-LABEL: irdl.operation @attributes {
+// CHECK-NEXT:    %[[v0:[^ ]*]] = irdl.base "!builtin.integer"
+// CHECK-NEXT:    %[[v1:[^ ]*]] = irdl.base @test::@"#test"
+// CHECK-NEXT:    irdl.attributes {"int_attr" = %[[v0]], "test_attr" = %[[v1]]}
+// CHECK-NEXT:  }
+
 // Check confined types are converted correctly.
 def Test_ConfinedOp : Test_Op<"confined"> {
   let arguments = (ins ConfinedType<AnyType, [CPred<"::llvm::isa<::mlir::TensorType>($_self)">]>:$tensor,
diff --git a/mlir/tools/tblgen-to-irdl/OpDefinitionsGen.cpp b/mlir/tools/tblgen-to-irdl/OpDefinitionsGen.cpp
index 45957bafc378e3..d0a3552fb123da 100644
--- a/mlir/tools/tblgen-to-irdl/OpDefinitionsGen.cpp
+++ b/mlir/tools/tblgen-to-irdl/OpDefinitionsGen.cpp
@@ -74,8 +74,14 @@ Value typeToConstraint(OpBuilder &builder, Type type) {
   return op.getOutput();
 }
 
-std::optional<Type> recordToType(MLIRContext *ctx, const Record &predRec) {
+Value baseToConstraint(OpBuilder &builder, StringRef baseClass) {
+  MLIRContext *ctx = builder.getContext();
+  auto op = builder.create<irdl::BaseOp>(UnknownLoc::get(ctx),
+                                         StringAttr::get(ctx, baseClass));
+  return op.getOutput();
+}
 
+std::optional<Type> recordToType(MLIRContext *ctx, const Record &predRec) {
   if (predRec.isSubClassOf("I")) {
     auto width = predRec.getValueAsInt("bitwidth");
     return IntegerType::get(ctx, width, IntegerType::Signless);
@@ -164,12 +170,12 @@ std::optional<Type> recordToType(MLIRContext *ctx, const Record &predRec) {
   return std::nullopt;
 }
 
-Value createConstraint(OpBuilder &builder, tblgen::Constraint constraint) {
+Value createTypeConstraint(OpBuilder &builder, tblgen::Constraint constraint) {
   MLIRContext *ctx = builder.getContext();
   const Record &predRec = constraint.getDef();
 
   if (predRec.isSubClassOf("Variadic") || predRec.isSubClassOf("Optional"))
-    return createConstraint(builder, predRec.getValueAsDef("baseType"));
+    return createTypeConstraint(builder, predRec.getValueAsDef("baseType"));
 
   if (predRec.getName() == "AnyType") {
     auto op = builder.create<irdl::AnyOp>(UnknownLoc::get(ctx));
@@ -196,7 +202,7 @@ Value createConstraint(OpBuilder &builder, tblgen::Constraint constraint) {
     std::vector<Value> constraints;
     for (const Record *child : predRec.getValueAsListOfDefs("allowedTypes")) {
       constraints.push_back(
-          createConstraint(builder, tblgen::Constraint(child)));
+          createTypeConstraint(builder, tblgen::Constraint(child)));
     }
     auto op = builder.create<irdl::AnyOfOp>(UnknownLoc::get(ctx), constraints);
     return op.getOutput();
@@ -206,7 +212,7 @@ Value createConstraint(OpBuilder &builder, tblgen::Constraint constraint) {
     std::vector<Value> constraints;
     for (const Record *child : predRec.getValueAsListOfDefs("allowedTypes")) {
       constraints.push_back(
-          createConstraint(builder, tblgen::Constraint(child)));
+          createTypeConstraint(builder, tblgen::Constraint(child)));
     }
     auto op = builder.create<irdl::AllOfOp>(UnknownLoc::get(ctx), constraints);
     return op.getOutput();
@@ -241,7 +247,7 @@ Value createConstraint(OpBuilder &builder, tblgen::Constraint constraint) {
   // Confined type
   if (predRec.isSubClassOf("ConfinedType")) {
     std::vector<Value> constraints;
-    constraints.push_back(createConstraint(
+    constraints.push_back(createTypeConstraint(
         builder, tblgen::Constraint(predRec.getValueAsDef("baseType"))));
     for (Record *child : predRec.getValueAsListOfDefs("predicateList")) {
       constraints.push_back(createPredicate(builder, tblgen::Pred(child)));
@@ -253,6 +259,85 @@ Value createConstraint(OpBuilder &builder, tblgen::Constraint constraint) {
   return createPredicate(builder, constraint.getPredicate());
 }
 
+Value createAttrConstraint(OpBuilder &builder, tblgen::Constraint constraint) {
+  MLIRContext *ctx = builder.getContext();
+  const Record &predRec = constraint.getDef();
+
+  if (predRec.isSubClassOf("DefaultValuedAttr") ||
+      predRec.isSubClassOf("DefaultValuedOptionalAttr") ||
+      predRec.isSubClassOf("OptionalAttr")) {
+    return createAttrConstraint(builder, predRec.getValueAsDef("baseAttr"));
+  }
+
+  if (predRec.isSubClassOf("ConfinedAttr")) {
+    std::vector<Value> constraints;
+    constraints.push_back(createAttrConstraint(
+        builder, tblgen::Constraint(predRec.getValueAsDef("baseAttr"))));
+    for (Record *child : predRec.getValueAsListOfDefs("attrConstraints")) {
+      constraints.push_back(createPredicate(
+          builder, tblgen::Pred(child->getValueAsDef("predicate"))));
+    }
+    auto op = builder.create<irdl::AllOfOp>(UnknownLoc::get(ctx), constraints);
+    return op.getOutput();
+  }
+
+  if (predRec.isSubClassOf("AnyAttrOf")) {
+    std::vector<Value> constraints;
+    for (Record *child : predRec.getValueAsListOfDefs("allowedAttributes")) {
+      constraints.push_back(
+          createAttrConstraint(builder, tblgen::Constraint(child)));
+    }
+    auto op = builder.create<irdl::AnyOfOp>(UnknownLoc::get(ctx), constraints);
+    return op.getOutput();
+  }
+
+  if (predRec.getName() == "AnyAttr") {
+    auto op = builder.create<irdl::AnyOp>(UnknownLoc::get(ctx));
+    return op.getOutput();
+  }
+
+  if (predRec.isSubClassOf("AnyIntegerAttrBase") ||
+      predRec.isSubClassOf("SignlessIntegerAttrBase") ||
+      predRec.isSubClassOf("SignedIntegerAttrBase") ||
+      predRec.isSubClassOf("UnsignedIntegerAttrBase") ||
+      predRec.isSubClassOf("BoolAttr")) {
+    return baseToConstraint(builder, "!builtin.integer");
+  }
+
+  if (predRec.isSubClassOf("FloatAttrBase")) {
+    return baseToConstraint(builder, "!builtin.float");
+  }
+
+  if (predRec.isSubClassOf("StringBasedAttr")) {
+    return baseToConstraint(builder, "!builtin.string");
+  }
+
+  if (predRec.getName() == "UnitAttr") {
+    auto op =
+        builder.create<irdl::IsOp>(UnknownLoc::get(ctx), UnitAttr::get(ctx));
+    return op.getOutput();
+  }
+
+  if (predRec.isSubClassOf("AttrDef")) {
+    auto dialect = predRec.getValueAsDef("dialect")->getValueAsString("name");
+    if (dialect == selectedDialect) {
+      std::string combined = ("#" + predRec.getValueAsString("mnemonic")).str();
+      SmallVector<FlatSymbolRefAttr> nested = {SymbolRefAttr::get(ctx, combined)
+
+      };
+      auto typeSymbol = SymbolRefAttr::get(ctx, dialect, nested);
+      auto op = builder.create<irdl::BaseOp>(UnknownLoc::get(ctx), typeSymbol);
+      return op.getOutput();
+    }
+    std::string typeName = ("#" + predRec.getValueAsString("attrName")).str();
+    auto op = builder.create<irdl::BaseOp>(UnknownLoc::get(ctx),
+                                           StringAttr::get(ctx, typeName));
+    return op.getOutput();
+  }
+
+  return createPredicate(builder, constraint.getPredicate());
+}
+
 /// Returns the name of the operation without the dialect prefix.
 static StringRef getOperatorName(tblgen::Operator &tblgenOp) {
   StringRef opName = tblgenOp.getDef().getValueAsString("opName");
@@ -265,6 +350,12 @@ static StringRef getTypeName(tblgen::TypeDef &tblgenType) {
   return opName;
 }
 
+/// Returns the name of the attr without the dialect prefix.
+static StringRef getAttrName(tblgen::AttrDef &tblgenType) {
+  StringRef opName = tblgenType.getDef()->getValueAsString("mnemonic");
+  return opName;
+}
+
 /// Extract an operation to IRDL.
 irdl::OperationOp createIRDLOperation(OpBuilder &builder,
                                       tblgen::Operator &tblgenOp) {
@@ -282,7 +373,7 @@ irdl::OperationOp createIRDLOperation(OpBuilder &builder,
     SmallVector<Value> operands;
     SmallVector<irdl::VariadicityAttr> variadicity;
     for (const NamedTypeConstraint &namedCons : namedCons) {
-      auto operand = createConstraint(consBuilder, namedCons.constraint);
+      auto operand = createTypeConstraint(consBuilder, namedCons.constraint);
       operands.push_back(operand);
 
       irdl::VariadicityAttr var;
@@ -304,6 +395,15 @@ irdl::OperationOp createIRDLOperation(OpBuilder &builder,
   auto [operands, operandVariadicity] = getValues(tblgenOp.getOperands());
   auto [results, resultVariadicity] = getValues(tblgenOp.getResults());
 
+  SmallVector<Value> attributes;
+  SmallVector<Attribute> attrNames;
+  for (auto namedAttr : tblgenOp.getAttributes()) {
+    if (namedAttr.attr.isOptional())
+      continue;
+    attributes.push_back(createAttrConstraint(consBuilder, namedAttr.attr));
+    attrNames.push_back(StringAttr::get(ctx, namedAttr.name));
+  }
+
   // Create the operands and results operations.
   if (!operands.empty())
     consBuilder.create<irdl::OperandsOp>(UnknownLoc::get(ctx), operands,
@@ -311,6 +411,9 @@ irdl::OperationOp createIRDLOperation(OpBuilder &builder,
   if (!results.empty())
     consBuilder.create<irdl::ResultsOp>(UnknownLoc::get(ctx), results,
                                         resultVariadicity);
+  if (!attributes.empty())
+    consBuilder.create<irdl::AttributesOp>(UnknownLoc::get(ctx), attributes,
+                                           ArrayAttr::get(ctx, attrNames));
 
   return op;
 }
@@ -328,6 +431,20 @@ irdl::TypeOp createIRDLType(OpBuilder &builder, tblgen::TypeDef &tblgenType) {
   return op;
 }
 
+irdl::AttributeOp createIRDLAttr(OpBuilder &builder,
+                                 tblgen::AttrDef &tblgenAttr) {
+  MLIRContext *ctx = builder.getContext();
+  StringRef attrName = getAttrName(tblgenAttr);
+  std::string combined = ("#" + attrName).str();
+
+  irdl::AttributeOp op = builder.create<irdl::AttributeOp>(
+      UnknownLoc::get(ctx), StringAttr::get(ctx, combined));
+
+  op.getBody().emplaceBlock();
+
+  return op;
+}
+
 static irdl::DialectOp createIRDLDialect(OpBuilder &builder) {
   MLIRContext *ctx = builder.getContext();
   return builder.create<irdl::DialectOp>(UnknownLoc::get(ctx),
@@ -358,6 +475,14 @@ static bool emitDialectIRDLDefs(const RecordKeeper &recordKeeper,
     createIRDLType(builder, tblgenType);
   }
 
+  for (const Record *attr :
+       recordKeeper.getAllDerivedDefinitionsIfDefined("AttrDef")) {
+    tblgen::AttrDef tblgenAttr(attr);
+    if (tblgenAttr.getDialect().getName() != selectedDialect)
+      continue;
+    createIRDLAttr(builder, tblgenAttr);
+  }
+
   for (const Record *def :
        recordKeeper.getAllDerivedDefinitionsIfDefined("Op")) {
     tblgen::Operator tblgenOp(def);

@math-fehr math-fehr self-requested a review September 27, 2024 16:44
Copy link
Contributor

@math-fehr math-fehr left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Very nice!

@math-fehr math-fehr merged commit 159470d into llvm:main Sep 27, 2024
8 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:core MLIR Core Infrastructure mlir:ods mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants