Skip to content

[MLIR] Introduce new C bindings to differentiate between discardable and inherent attributes #66332

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Sep 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 52 additions & 0 deletions mlir/include/mlir-c/IR.h
Original file line number Diff line number Diff line change
Expand Up @@ -552,25 +552,77 @@ MLIR_CAPI_EXPORTED intptr_t mlirOperationGetNumSuccessors(MlirOperation op);
MLIR_CAPI_EXPORTED MlirBlock mlirOperationGetSuccessor(MlirOperation op,
intptr_t pos);

/// Returns true if this operation defines an inherent attribute with this name.
/// Note: the attribute can be optional, so
/// `mlirOperationGetInherentAttributeByName` can still return a null attribute.
MLIR_CAPI_EXPORTED bool
mlirOperationHasInherentAttributeByName(MlirOperation op, MlirStringRef name);

/// Returns an inherent attribute attached to the operation given its name.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: what happens when there is no attribute? (can we differentiate between an optional inherent attribute being absent and the name not corresponding to an inherent attribute?)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could, I wasn't sure if it was useful for the C API, the alternative may be to add a bool hasInherentAttr(name) API that the use could check on top of this one to distinguish?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just added this API, but let me know if you have a better idea.

MLIR_CAPI_EXPORTED MlirAttribute
mlirOperationGetInherentAttributeByName(MlirOperation op, MlirStringRef name);

/// Sets an inherent attribute by name, replacing the existing if it exists.
/// This has no effect if "name" does not match an inherent attribute.
MLIR_CAPI_EXPORTED void
mlirOperationSetInherentAttributeByName(MlirOperation op, MlirStringRef name,
MlirAttribute attr);

/// Returns the number of discardable attributes attached to the operation.
MLIR_CAPI_EXPORTED intptr_t
mlirOperationGetNumDiscardableAttributes(MlirOperation op);

/// Return `pos`-th discardable attribute of the operation.
MLIR_CAPI_EXPORTED MlirNamedAttribute
mlirOperationGetDiscardableAttribute(MlirOperation op, intptr_t pos);

/// Returns a discardable attribute attached to the operation given its name.
MLIR_CAPI_EXPORTED MlirAttribute mlirOperationGetDiscardableAttributeByName(
MlirOperation op, MlirStringRef name);

/// Sets a discardable attribute by name, replacing the existing if it exists or
/// adding a new one otherwise. The new `attr` Attribute is not allowed to be
/// null, use `mlirOperationRemoveDiscardableAttributeByName` to remove an
/// Attribute instead.
MLIR_CAPI_EXPORTED void
mlirOperationSetDiscardableAttributeByName(MlirOperation op, MlirStringRef name,
MlirAttribute attr);

/// Removes a discardable attribute by name. Returns false if the attribute was
/// not found and true if removed.
MLIR_CAPI_EXPORTED bool
mlirOperationRemoveDiscardableAttributeByName(MlirOperation op,
MlirStringRef name);

/// Returns the number of attributes attached to the operation.
/// Deprecated, please use `mlirOperationGetNumInherentAttributes` or
/// `mlirOperationGetNumDiscardableAttributes`.
MLIR_CAPI_EXPORTED intptr_t mlirOperationGetNumAttributes(MlirOperation op);

/// Return `pos`-th attribute of the operation.
/// Deprecated, please use `mlirOperationGetInherentAttribute` or
/// `mlirOperationGetDiscardableAttribute`.
MLIR_CAPI_EXPORTED MlirNamedAttribute
mlirOperationGetAttribute(MlirOperation op, intptr_t pos);

/// Returns an attribute attached to the operation given its name.
/// Deprecated, please use `mlirOperationGetInherentAttributeByName` or
/// `mlirOperationGetDiscardableAttributeByName`.
MLIR_CAPI_EXPORTED MlirAttribute
mlirOperationGetAttributeByName(MlirOperation op, MlirStringRef name);

/// Sets an attribute by name, replacing the existing if it exists or
/// adding a new one otherwise.
/// Deprecated, please use `mlirOperationSetInherentAttributeByName` or
/// `mlirOperationSetDiscardableAttributeByName`.
MLIR_CAPI_EXPORTED void mlirOperationSetAttributeByName(MlirOperation op,
MlirStringRef name,
MlirAttribute attr);

/// Removes an attribute by name. Returns false if the attribute was not found
/// and true if removed.
/// Deprecated, please use `mlirOperationRemoveInherentAttributeByName` or
/// `mlirOperationRemoveDiscardableAttributeByName`.
MLIR_CAPI_EXPORTED bool mlirOperationRemoveAttributeByName(MlirOperation op,
MlirStringRef name);

Expand Down
17 changes: 17 additions & 0 deletions mlir/include/mlir/IR/Operation.h
Original file line number Diff line number Diff line change
Expand Up @@ -457,6 +457,23 @@ class alignas(8) Operation final
if (attributes.set(name, value) != value)
attrs = attributes.getDictionary(getContext());
}
void setDiscardableAttr(StringRef name, Attribute value) {
setDiscardableAttr(StringAttr::get(getContext(), name), value);
}

/// Remove the discardable attribute with the specified name if it exists.
/// Return the attribute that was erased, or nullptr if there was no attribute
/// with such name.
Attribute removeDiscardableAttr(StringAttr name) {
NamedAttrList attributes(attrs);
Attribute removedAttr = attributes.erase(name);
if (removedAttr)
attrs = attributes.getDictionary(getContext());
return removedAttr;
}
Attribute removeDiscardableAttr(StringRef name) {
return removeDiscardableAttr(StringAttr::get(getContext(), name));
}

/// Return all of the discardable attributes on this operation.
ArrayRef<NamedAttribute> getDiscardableAttrs() { return attrs.getValue(); }
Expand Down
47 changes: 47 additions & 0 deletions mlir/lib/CAPI/IR/IR.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -550,6 +550,53 @@ MlirBlock mlirOperationGetSuccessor(MlirOperation op, intptr_t pos) {
return wrap(unwrap(op)->getSuccessor(static_cast<unsigned>(pos)));
}

MLIR_CAPI_EXPORTED bool
mlirOperationHasInherentAttributeByName(MlirOperation op, MlirStringRef name) {
std::optional<Attribute> attr = unwrap(op)->getInherentAttr(unwrap(name));
return attr.has_value();
}

MlirAttribute mlirOperationGetInherentAttributeByName(MlirOperation op,
MlirStringRef name) {
std::optional<Attribute> attr = unwrap(op)->getInherentAttr(unwrap(name));
if (attr.has_value())
return wrap(*attr);
return {};
}

void mlirOperationSetInherentAttributeByName(MlirOperation op,
MlirStringRef name,
MlirAttribute attr) {
unwrap(op)->setInherentAttr(
StringAttr::get(unwrap(op)->getContext(), unwrap(name)), unwrap(attr));
}

intptr_t mlirOperationGetNumDiscardableAttributes(MlirOperation op) {
return static_cast<intptr_t>(unwrap(op)->getDiscardableAttrs().size());
}

MlirNamedAttribute mlirOperationGetDiscardableAttribute(MlirOperation op,
intptr_t pos) {
NamedAttribute attr = unwrap(op)->getDiscardableAttrs()[pos];
return MlirNamedAttribute{wrap(attr.getName()), wrap(attr.getValue())};
}

MlirAttribute mlirOperationGetDiscardableAttributeByName(MlirOperation op,
MlirStringRef name) {
return wrap(unwrap(op)->getDiscardableAttr(unwrap(name)));
}

void mlirOperationSetDiscardableAttributeByName(MlirOperation op,
MlirStringRef name,
MlirAttribute attr) {
unwrap(op)->setDiscardableAttr(unwrap(name), unwrap(attr));
}

bool mlirOperationRemoveDiscardableAttributeByName(MlirOperation op,
MlirStringRef name) {
return !!unwrap(op)->removeDiscardableAttr(unwrap(name));
}

intptr_t mlirOperationGetNumAttributes(MlirOperation op) {
return static_cast<intptr_t>(unwrap(op)->getAttrs().size());
}
Expand Down
43 changes: 21 additions & 22 deletions mlir/test/CAPI/ir.c
Original file line number Diff line number Diff line change
Expand Up @@ -407,24 +407,23 @@ static void printFirstOfEach(MlirContext ctx, MlirOperation operation) {
fprintf(stderr, "\n");
// CHECK: Terminator: func.return

// Get the attribute by index.
MlirNamedAttribute namedAttr0 = mlirOperationGetAttribute(operation, 0);
fprintf(stderr, "Get attr 0: ");
mlirAttributePrint(namedAttr0.attribute, printToStderr, NULL);
// Get the attribute by name.
bool hasValueAttr = mlirOperationHasInherentAttributeByName(
operation, mlirStringRefCreateFromCString("value"));
if (hasValueAttr)
// CHECK: Has attr "value"
fprintf(stderr, "Has attr \"value\"");

MlirAttribute valueAttr0 = mlirOperationGetInherentAttributeByName(
operation, mlirStringRefCreateFromCString("value"));
fprintf(stderr, "Get attr \"value\": ");
mlirAttributePrint(valueAttr0, printToStderr, NULL);
fprintf(stderr, "\n");
// CHECK: Get attr 0: 0 : index

// Now re-get the attribute by name.
MlirAttribute attr0ByName = mlirOperationGetAttributeByName(
operation, mlirIdentifierStr(namedAttr0.name));
fprintf(stderr, "Get attr 0 by name: ");
mlirAttributePrint(attr0ByName, printToStderr, NULL);
fprintf(stderr, "\n");
// CHECK: Get attr 0 by name: 0 : index
// CHECK: Get attr "value": 0 : index

// Get a non-existing attribute and assert that it is null (sanity).
fprintf(stderr, "does_not_exist is null: %d\n",
mlirAttributeIsNull(mlirOperationGetAttributeByName(
mlirAttributeIsNull(mlirOperationGetDiscardableAttributeByName(
operation, mlirStringRefCreateFromCString("does_not_exist"))));
// CHECK: does_not_exist is null: 1

Expand All @@ -443,24 +442,24 @@ static void printFirstOfEach(MlirContext ctx, MlirOperation operation) {
fprintf(stderr, "\n");
// CHECK: Result 0 type: index

// Set a custom attribute.
mlirOperationSetAttributeByName(operation,
mlirStringRefCreateFromCString("custom_attr"),
mlirBoolAttrGet(ctx, 1));
// Set a discardable attribute.
mlirOperationSetDiscardableAttributeByName(
operation, mlirStringRefCreateFromCString("custom_attr"),
mlirBoolAttrGet(ctx, 1));
fprintf(stderr, "Op with set attr: ");
mlirOperationPrint(operation, printToStderr, NULL);
fprintf(stderr, "\n");
// CHECK: Op with set attr: {{.*}} {custom_attr = true}

// Remove the attribute.
fprintf(stderr, "Remove attr: %d\n",
mlirOperationRemoveAttributeByName(
mlirOperationRemoveDiscardableAttributeByName(
operation, mlirStringRefCreateFromCString("custom_attr")));
fprintf(stderr, "Remove attr again: %d\n",
mlirOperationRemoveAttributeByName(
mlirOperationRemoveDiscardableAttributeByName(
operation, mlirStringRefCreateFromCString("custom_attr")));
fprintf(stderr, "Removed attr is null: %d\n",
mlirAttributeIsNull(mlirOperationGetAttributeByName(
mlirAttributeIsNull(mlirOperationGetDiscardableAttributeByName(
operation, mlirStringRefCreateFromCString("custom_attr"))));
// CHECK: Remove attr: 1
// CHECK: Remove attr again: 0
Expand All @@ -469,7 +468,7 @@ static void printFirstOfEach(MlirContext ctx, MlirOperation operation) {
// Add a large attribute to verify printing flags.
int64_t eltsShape[] = {4};
int32_t eltsData[] = {1, 2, 3, 4};
mlirOperationSetAttributeByName(
mlirOperationSetDiscardableAttributeByName(
operation, mlirStringRefCreateFromCString("elts"),
mlirDenseElementsAttrInt32Get(
mlirRankedTensorTypeGet(1, eltsShape, mlirIntegerTypeGet(ctx, 32),
Expand Down