Skip to content

Reapply "[mlir][python] allow DenseIntElementsAttr for index type (#118947)" #124804

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 29, 2025

Conversation

mgehre-amd
Copy link
Contributor

This reapplies #118947 and adapts to nanobind.

@mgehre-amd mgehre-amd marked this pull request as ready for review January 29, 2025 07:35
@llvmbot llvmbot added the mlir label Jan 29, 2025
@llvmbot
Copy link
Member

llvmbot commented Jan 29, 2025

@llvm/pr-subscribers-mlir

Author: Matthias Gehre (mgehre-amd)

Changes

This reapplies #118947 and adapts to nanobind.


Full diff: https://github.com/llvm/llvm-project/pull/124804.diff

6 Files Affected:

  • (modified) mlir/include/mlir-c/BuiltinAttributes.h (+2)
  • (modified) mlir/lib/Bindings/Python/IRAttributes.cpp (+8-2)
  • (modified) mlir/lib/CAPI/IR/BuiltinAttributes.cpp (+3)
  • (modified) mlir/test/python/dialects/builtin.py (+4)
  • (modified) mlir/test/python/ir/array_attributes.py (+4)
  • (modified) mlir/test/python/ir/attributes.py (+4)
diff --git a/mlir/include/mlir-c/BuiltinAttributes.h b/mlir/include/mlir-c/BuiltinAttributes.h
index 7c8c84e55b962f..1d0edf9ea809d2 100644
--- a/mlir/include/mlir-c/BuiltinAttributes.h
+++ b/mlir/include/mlir-c/BuiltinAttributes.h
@@ -556,6 +556,8 @@ MLIR_CAPI_EXPORTED int64_t
 mlirDenseElementsAttrGetInt64Value(MlirAttribute attr, intptr_t pos);
 MLIR_CAPI_EXPORTED uint64_t
 mlirDenseElementsAttrGetUInt64Value(MlirAttribute attr, intptr_t pos);
+MLIR_CAPI_EXPORTED uint64_t
+mlirDenseElementsAttrGetIndexValue(MlirAttribute attr, intptr_t pos);
 MLIR_CAPI_EXPORTED float mlirDenseElementsAttrGetFloatValue(MlirAttribute attr,
                                                             intptr_t pos);
 MLIR_CAPI_EXPORTED double
diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp
index 7bc21a31c3c84c..d3ceb3d435c1c0 100644
--- a/mlir/lib/Bindings/Python/IRAttributes.cpp
+++ b/mlir/lib/Bindings/Python/IRAttributes.cpp
@@ -1372,13 +1372,19 @@ class PyDenseIntElementsAttribute
 
     MlirType type = mlirAttributeGetType(*this);
     type = mlirShapedTypeGetElementType(type);
-    assert(mlirTypeIsAInteger(type) &&
-           "expected integer element type in dense int elements attribute");
+    // Index type can also appear as a DenseIntElementsAttr and therefore can be
+    // casted to integer.
+    assert(mlirTypeIsAInteger(type) ||
+           mlirTypeIsAIndex(type) && "expected integer/index element type in "
+                                     "dense int elements attribute");
     // Dispatch element extraction to an appropriate C function based on the
     // elemental type of the attribute. nb::int_ is implicitly constructible
     // from any C++ integral type and handles bitwidth correctly.
     // TODO: consider caching the type properties in the constructor to avoid
     // querying them on each element access.
+    if (mlirTypeIsAIndex(type)) {
+      return nb::int_(mlirDenseElementsAttrGetIndexValue(*this, pos));
+    }
     unsigned width = mlirIntegerTypeGetWidth(type);
     bool isUnsigned = mlirIntegerTypeIsUnsigned(type);
     if (isUnsigned) {
diff --git a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
index 11d1ade552f5a2..8d57ab6b59e79c 100644
--- a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
@@ -758,6 +758,9 @@ int64_t mlirDenseElementsAttrGetInt64Value(MlirAttribute attr, intptr_t pos) {
 uint64_t mlirDenseElementsAttrGetUInt64Value(MlirAttribute attr, intptr_t pos) {
   return llvm::cast<DenseElementsAttr>(unwrap(attr)).getValues<uint64_t>()[pos];
 }
+uint64_t mlirDenseElementsAttrGetIndexValue(MlirAttribute attr, intptr_t pos) {
+  return llvm::cast<DenseElementsAttr>(unwrap(attr)).getValues<uint64_t>()[pos];
+}
 float mlirDenseElementsAttrGetFloatValue(MlirAttribute attr, intptr_t pos) {
   return llvm::cast<DenseElementsAttr>(unwrap(attr)).getValues<float>()[pos];
 }
diff --git a/mlir/test/python/dialects/builtin.py b/mlir/test/python/dialects/builtin.py
index 18ebba61e7fea8..973a0eaeca2cdb 100644
--- a/mlir/test/python/dialects/builtin.py
+++ b/mlir/test/python/dialects/builtin.py
@@ -246,3 +246,7 @@ def testDenseElementsAttr():
         # CHECK{LITERAL}: dense<[[0, 1], [2, 3]]> : tensor<2x2xi32>
         print(DenseElementsAttr.get(values, type=VectorType.get((2, 2), i32)))
         # CHECK{LITERAL}: dense<[[0, 1], [2, 3]]> : vector<2x2xi32>
+        idx_values = np.arange(4, dtype=np.int64)
+        idx_type = IndexType.get()
+        print(DenseElementsAttr.get(idx_values, type=VectorType.get([4], idx_type)))
+        # CHECK{LITERAL}: dense<[0, 1, 2, 3]> : vector<4xindex>
diff --git a/mlir/test/python/ir/array_attributes.py b/mlir/test/python/ir/array_attributes.py
index 256a69a939658d..ef1d835fc64012 100644
--- a/mlir/test/python/ir/array_attributes.py
+++ b/mlir/test/python/ir/array_attributes.py
@@ -572,6 +572,10 @@ def testGetDenseElementsIndex():
         print(arr)
         # CHECK: True
         print(arr.dtype == np.int64)
+        array = np.array([1, 2, 3], dtype=np.int64)
+        attr = DenseIntElementsAttr.get(array, type=VectorType.get([3], idx_type))
+        # CHECK: [1, 2, 3]
+        print(list(DenseIntElementsAttr(attr)))
 
 
 # CHECK-LABEL: TEST: testGetDenseResourceElementsAttr
diff --git a/mlir/test/python/ir/attributes.py b/mlir/test/python/ir/attributes.py
index 00c3e1b4decdb7..2f3c4460d3f590 100644
--- a/mlir/test/python/ir/attributes.py
+++ b/mlir/test/python/ir/attributes.py
@@ -366,6 +366,10 @@ def testDenseIntAttr():
         # CHECK: i1
         print(ShapedType(a.type).element_type)
 
+        shape = Attribute.parse("dense<[0, 1, 2, 3]> : vector<4xindex>")
+        # CHECK: attr: dense<[0, 1, 2, 3]>
+        print("attr:", shape)
+
 
 @run
 def testDenseArrayGetItem():

Copy link
Contributor

@makslevental makslevental left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@mgehre-amd mgehre-amd merged commit 5d3ae51 into llvm:main Jan 29, 2025
12 checks passed
@mgehre-amd mgehre-amd deleted the matthias.reland_python_index branch January 29, 2025 08:14
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants