Skip to content

[MLIR] Introduce new C bindings to differentiate between discardable and inherent attributes #66332

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Sep 26, 2023

Conversation

joker-eph
Copy link
Collaborator

This is part of the transition toward properly splitting the two groups. This only introduces new C APIs, the Python bindings are unaffected. No API is removed.

@joker-eph joker-eph requested a review from ftynse September 14, 2023 06:43
@llvmbot llvmbot added mlir:core MLIR Core Infrastructure mlir labels Sep 14, 2023
@llvmbot
Copy link
Member

llvmbot commented Sep 14, 2023

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-core

Changes This is part of the transition toward properly splitting the two groups. This only introduces new C APIs, the Python bindings are unaffected. No API is removed. -- Full diff: https://github.com//pull/66332.diff

4 Files Affected:

  • (modified) mlir/include/mlir-c/IR.h (+44)
  • (modified) mlir/include/mlir/IR/Operation.h (+17)
  • (modified) mlir/lib/CAPI/IR/IR.cpp (+41)
  • (modified) mlir/test/CAPI/ir.c (+11-19)
diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
index b5c6a3094bc67df..f15e7bbd89fb29d 100644
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -552,25 +552,69 @@ MLIR_CAPI_EXPORTED intptr_t mlirOperationGetNumSuccessors(MlirOperation op);
 MLIR_CAPI_EXPORTED MlirBlock mlirOperationGetSuccessor(MlirOperation op,
                                                        intptr_t pos);
 
+/// Returns an inherent attribute attached to the operation given its name.
+MLIR_CAPI_EXPORTED MlirAttribute
+mlirOperationGetInherentAttributeByName(MlirOperation op, MlirStringRef name);
+
+/// Sets an inherent attribute by name, replacing the existing if it exists.
+/// This has no effect if "name" does not match an inherent attribute.
+MLIR_CAPI_EXPORTED void
+mlirOperationSetInherentAttributeByName(MlirOperation op, MlirStringRef name,
+                                        MlirAttribute attr);
+
+/// Returns the number of discardable attributes attached to the operation.
+MLIR_CAPI_EXPORTED intptr_t
+mlirOperationGetNumDiscardableAttributes(MlirOperation op);
+
+/// Return `pos`-th discardable attribute of the operation.
+MLIR_CAPI_EXPORTED MlirNamedAttribute
+mlirOperationGetDiscardableAttribute(MlirOperation op, intptr_t pos);
+
+/// Returns a discardable attribute attached to the operation given its name.
+MLIR_CAPI_EXPORTED MlirAttribute mlirOperationGetDiscardableAttributeByName(
+    MlirOperation op, MlirStringRef name);
+
+/// Sets a discardable attribute by name, replacing the existing if it exists or
+/// adding a new one otherwise.
+MLIR_CAPI_EXPORTED void
+mlirOperationSetDiscardableAttributeByName(MlirOperation op, MlirStringRef name,
+                                           MlirAttribute attr);
+
+/// Removes a discardable attribute by name. Returns false if the attribute was
+/// not found and true if removed.
+MLIR_CAPI_EXPORTED bool
+mlirOperationRemoveDiscardableAttributeByName(MlirOperation op,
+                                              MlirStringRef name);
+
 /// Returns the number of attributes attached to the operation.
+/// Deprecated, please use `mlirOperationGetNumInherentAttributes` or
+/// `mlirOperationGetNumDiscardableAttributes`.
 MLIR_CAPI_EXPORTED intptr_t mlirOperationGetNumAttributes(MlirOperation op);
 
 /// Return `pos`-th attribute of the operation.
+/// Deprecated, please use `mlirOperationGetInherentAttribute` or
+/// `mlirOperationGetDiscardableAttribute`.
 MLIR_CAPI_EXPORTED MlirNamedAttribute
 mlirOperationGetAttribute(MlirOperation op, intptr_t pos);
 
 /// Returns an attribute attached to the operation given its name.
+/// Deprecated, please use `mlirOperationGetInherentAttributeByName` or
+/// `mlirOperationGetDiscardableAttributeByName`.
 MLIR_CAPI_EXPORTED MlirAttribute
 mlirOperationGetAttributeByName(MlirOperation op, MlirStringRef name);
 
 /// Sets an attribute by name, replacing the existing if it exists or
 /// adding a new one otherwise.
+/// Deprecated, please use `mlirOperationSetInherentAttributeByName` or
+/// `mlirOperationSetDiscardableAttributeByName`.
 MLIR_CAPI_EXPORTED void mlirOperationSetAttributeByName(MlirOperation op,
                                                         MlirStringRef name,
                                                         MlirAttribute attr);
 
 /// Removes an attribute by name. Returns false if the attribute was not found
 /// and true if removed.
+/// Deprecated, please use `mlirOperationRemoveInherentAttributeByName` or
+/// `mlirOperationRemoveDiscardableAttributeByName`.
 MLIR_CAPI_EXPORTED bool mlirOperationRemoveAttributeByName(MlirOperation op,
                                                            MlirStringRef name);
 
diff --git a/mlir/include/mlir/IR/Operation.h b/mlir/include/mlir/IR/Operation.h
index b815eaf8899d6fc..b42ec231f1cbb20 100644
--- a/mlir/include/mlir/IR/Operation.h
+++ b/mlir/include/mlir/IR/Operation.h
@@ -457,6 +457,23 @@ class alignas(8) Operation final
     if (attributes.set(name, value) != value)
       attrs = attributes.getDictionary(getContext());
   }
+  void setDiscardableAttr(StringRef name, Attribute value) {
+    setDiscardableAttr(StringAttr::get(getContext(), name), value);
+  }
+
+  /// Remove the discardable attribute with the specified name if it exists. Return the
+  /// attribute that was erased, or nullptr if there was no attribute with such
+  /// name.
+  Attribute removeDiscardableAttr(StringAttr name) {
+    NamedAttrList attributes(attrs);
+    Attribute removedAttr = attributes.erase(name);
+    if (removedAttr)
+      attrs = attributes.getDictionary(getContext());
+    return removedAttr;
+  }
+  Attribute removeDiscardableAttr(StringRef name) {
+    return removeDiscardableAttr(StringAttr::get(getContext(), name));
+  }
 
   /// Return all of the discardable attributes on this operation.
   ArrayRef<NamedAttribute> getDiscardableAttrs() { return attrs.getValue(); }
diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
index ef234a912490eea..2e17e5b920e6268 100644
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -550,6 +550,47 @@ MlirBlock mlirOperationGetSuccessor(MlirOperation op, intptr_t pos) {
   return wrap(unwrap(op)->getSuccessor(static_cast<unsigned>(pos)));
 }
 
+MlirAttribute mlirOperationGetInherentAttributeByName(MlirOperation op,
+                                                      MlirStringRef name) {
+  std::optional<Attribute> attr = unwrap(op)->getInherentAttr(unwrap(name));
+  if (attr.has_value())
+    return wrap(*attr);
+  return {};
+}
+
+void mlirOperationSetInherentAttributeByName(MlirOperation op,
+                                             MlirStringRef name,
+                                             MlirAttribute attr) {
+  unwrap(op)->setInherentAttr(
+      StringAttr::get(unwrap(op)->getContext(), unwrap(name)), unwrap(attr));
+}
+
+intptr_t mlirOperationGetNumDiscardableAttributes(MlirOperation op) {
+  return static_cast<intptr_t>(unwrap(op)->getDiscardableAttrs().size());
+}
+
+MlirNamedAttribute mlirOperationGetDiscardableAttribute(MlirOperation op,
+                                                        intptr_t pos) {
+  NamedAttribute attr = unwrap(op)->getDiscardableAttrs()[pos];
+  return MlirNamedAttribute{wrap(attr.getName()), wrap(attr.getValue())};
+}
+
+MlirAttribute mlirOperationGetDiscardableAttributeByName(MlirOperation op,
+                                                         MlirStringRef name) {
+  return wrap(unwrap(op)->getDiscardableAttr(unwrap(name)));
+}
+
+void mlirOperationSetDiscardableAttributeByName(MlirOperation op,
+                                                MlirStringRef name,
+                                                MlirAttribute attr) {
+  unwrap(op)->setDiscardableAttr(unwrap(name), unwrap(attr));
+}
+
+bool mlirOperationRemoveDiscardableAttributeByName(MlirOperation op,
+                                                   MlirStringRef name) {
+  return !!unwrap(op)->removeDiscardableAttr(unwrap(name));
+}
+
 intptr_t mlirOperationGetNumAttributes(MlirOperation op) {
   return static_cast<intptr_t>(unwrap(op)->getAttrs().size());
 }
diff --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c
index 5d78daa296501f4..88e701315d618e2 100644
--- a/mlir/test/CAPI/ir.c
+++ b/mlir/test/CAPI/ir.c
@@ -397,23 +397,15 @@ static void printFirstOfEach(MlirContext ctx, MlirOperation operation) {
   // CHECK: Terminator: func.return
 
   // Get the attribute by index.
-  MlirNamedAttribute namedAttr0 = mlirOperationGetAttribute(operation, 0);
-  fprintf(stderr, "Get attr 0: ");
-  mlirAttributePrint(namedAttr0.attribute, printToStderr, NULL);
+  MlirAttribute valueAttr0 = mlirOperationGetInherentAttributeByName(operation, mlirStringRefCreateFromCString("value"));
+  fprintf(stderr, "Get attr \"value\": ");
+  mlirAttributePrint(valueAttr0, printToStderr, NULL);
   fprintf(stderr, "\n");
-  // CHECK: Get attr 0: 0 : index
-
-  // Now re-get the attribute by name.
-  MlirAttribute attr0ByName = mlirOperationGetAttributeByName(
-      operation, mlirIdentifierStr(namedAttr0.name));
-  fprintf(stderr, "Get attr 0 by name: ");
-  mlirAttributePrint(attr0ByName, printToStderr, NULL);
-  fprintf(stderr, "\n");
-  // CHECK: Get attr 0 by name: 0 : index
+  // CHECK: Get attr "value": 0 : index
 
   // Get a non-existing attribute and assert that it is null (sanity).
   fprintf(stderr, "does_not_exist is null: %d\n",
-          mlirAttributeIsNull(mlirOperationGetAttributeByName(
+          mlirAttributeIsNull(mlirOperationGetDiscardableAttributeByName(
               operation, mlirStringRefCreateFromCString("does_not_exist"))));
   // CHECK: does_not_exist is null: 1
 
@@ -432,8 +424,8 @@ static void printFirstOfEach(MlirContext ctx, MlirOperation operation) {
   fprintf(stderr, "\n");
   // CHECK: Result 0 type: index
 
-  // Set a custom attribute.
-  mlirOperationSetAttributeByName(operation,
+  // Set a discardable attribute.
+  mlirOperationSetDiscardableAttributeByName(operation,
                                   mlirStringRefCreateFromCString("custom_attr"),
                                   mlirBoolAttrGet(ctx, 1));
   fprintf(stderr, "Op with set attr: ");
@@ -443,13 +435,13 @@ static void printFirstOfEach(MlirContext ctx, MlirOperation operation) {
 
   // Remove the attribute.
   fprintf(stderr, "Remove attr: %d\n",
-          mlirOperationRemoveAttributeByName(
+          mlirOperationRemoveDiscardableAttributeByName(
               operation, mlirStringRefCreateFromCString("custom_attr")));
   fprintf(stderr, "Remove attr again: %d\n",
-          mlirOperationRemoveAttributeByName(
+          mlirOperationRemoveDiscardableAttributeByName(
               operation, mlirStringRefCreateFromCString("custom_attr")));
   fprintf(stderr, "Removed attr is null: %d\n",
-          mlirAttributeIsNull(mlirOperationGetAttributeByName(
+          mlirAttributeIsNull(mlirOperationGetDiscardableAttributeByName(
               operation, mlirStringRefCreateFromCString("custom_attr"))));
   // CHECK: Remove attr: 1
   // CHECK: Remove attr again: 0
@@ -458,7 +450,7 @@ static void printFirstOfEach(MlirContext ctx, MlirOperation operation) {
   // Add a large attribute to verify printing flags.
   int64_t eltsShape[] = {4};
   int32_t eltsData[] = {1, 2, 3, 4};
-  mlirOperationSetAttributeByName(
+  mlirOperationSetDiscardableAttributeByName(
       operation, mlirStringRefCreateFromCString("elts"),
       mlirDenseElementsAttrInt32Get(
           mlirRankedTensorTypeGet(1, eltsShape, mlirIntegerTypeGet(ctx, 32),

@@ -552,25 +552,69 @@ MLIR_CAPI_EXPORTED intptr_t mlirOperationGetNumSuccessors(MlirOperation op);
MLIR_CAPI_EXPORTED MlirBlock mlirOperationGetSuccessor(MlirOperation op,
intptr_t pos);

/// Returns an inherent attribute attached to the operation given its name.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: what happens when there is no attribute? (can we differentiate between an optional inherent attribute being absent and the name not corresponding to an inherent attribute?)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could, I wasn't sure if it was useful for the C API, the alternative may be to add a bool hasInherentAttr(name) API that the use could check on top of this one to distinguish?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just added this API, but let me know if you have a better idea.

…and inherent attributes

This is part of the transition toward properly splitting the two groups.
This only introduces new C APIs, the Python bindings are unaffected.
No API is removed.
@joker-eph joker-eph merged commit 7675f54 into llvm:main Sep 26, 2023
@joker-eph joker-eph deleted the c-api-attributes branch September 26, 2023 08:53
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:core MLIR Core Infrastructure mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants