Skip to content

Commit 5d3ae51

Browse files
authored
Reapply "[mlir][python] allow DenseIntElementsAttr for index type (#118947)" (#124804)
This reapplies #118947 and adapts to nanobind.
1 parent c836b89 commit 5d3ae51

File tree

6 files changed

+25
-2
lines changed

6 files changed

+25
-2
lines changed

mlir/include/mlir-c/BuiltinAttributes.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -556,6 +556,8 @@ MLIR_CAPI_EXPORTED int64_t
556556
mlirDenseElementsAttrGetInt64Value(MlirAttribute attr, intptr_t pos);
557557
MLIR_CAPI_EXPORTED uint64_t
558558
mlirDenseElementsAttrGetUInt64Value(MlirAttribute attr, intptr_t pos);
559+
MLIR_CAPI_EXPORTED uint64_t
560+
mlirDenseElementsAttrGetIndexValue(MlirAttribute attr, intptr_t pos);
559561
MLIR_CAPI_EXPORTED float mlirDenseElementsAttrGetFloatValue(MlirAttribute attr,
560562
intptr_t pos);
561563
MLIR_CAPI_EXPORTED double

mlir/lib/Bindings/Python/IRAttributes.cpp

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1372,13 +1372,19 @@ class PyDenseIntElementsAttribute
13721372

13731373
MlirType type = mlirAttributeGetType(*this);
13741374
type = mlirShapedTypeGetElementType(type);
1375-
assert(mlirTypeIsAInteger(type) &&
1376-
"expected integer element type in dense int elements attribute");
1375+
// Index type can also appear as a DenseIntElementsAttr and therefore can be
1376+
// casted to integer.
1377+
assert(mlirTypeIsAInteger(type) ||
1378+
mlirTypeIsAIndex(type) && "expected integer/index element type in "
1379+
"dense int elements attribute");
13771380
// Dispatch element extraction to an appropriate C function based on the
13781381
// elemental type of the attribute. nb::int_ is implicitly constructible
13791382
// from any C++ integral type and handles bitwidth correctly.
13801383
// TODO: consider caching the type properties in the constructor to avoid
13811384
// querying them on each element access.
1385+
if (mlirTypeIsAIndex(type)) {
1386+
return nb::int_(mlirDenseElementsAttrGetIndexValue(*this, pos));
1387+
}
13821388
unsigned width = mlirIntegerTypeGetWidth(type);
13831389
bool isUnsigned = mlirIntegerTypeIsUnsigned(type);
13841390
if (isUnsigned) {

mlir/lib/CAPI/IR/BuiltinAttributes.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -758,6 +758,9 @@ int64_t mlirDenseElementsAttrGetInt64Value(MlirAttribute attr, intptr_t pos) {
758758
uint64_t mlirDenseElementsAttrGetUInt64Value(MlirAttribute attr, intptr_t pos) {
759759
return llvm::cast<DenseElementsAttr>(unwrap(attr)).getValues<uint64_t>()[pos];
760760
}
761+
uint64_t mlirDenseElementsAttrGetIndexValue(MlirAttribute attr, intptr_t pos) {
762+
return llvm::cast<DenseElementsAttr>(unwrap(attr)).getValues<uint64_t>()[pos];
763+
}
761764
float mlirDenseElementsAttrGetFloatValue(MlirAttribute attr, intptr_t pos) {
762765
return llvm::cast<DenseElementsAttr>(unwrap(attr)).getValues<float>()[pos];
763766
}

mlir/test/python/dialects/builtin.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -246,3 +246,7 @@ def testDenseElementsAttr():
246246
# CHECK{LITERAL}: dense<[[0, 1], [2, 3]]> : tensor<2x2xi32>
247247
print(DenseElementsAttr.get(values, type=VectorType.get((2, 2), i32)))
248248
# CHECK{LITERAL}: dense<[[0, 1], [2, 3]]> : vector<2x2xi32>
249+
idx_values = np.arange(4, dtype=np.int64)
250+
idx_type = IndexType.get()
251+
print(DenseElementsAttr.get(idx_values, type=VectorType.get([4], idx_type)))
252+
# CHECK{LITERAL}: dense<[0, 1, 2, 3]> : vector<4xindex>

mlir/test/python/ir/array_attributes.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -572,6 +572,10 @@ def testGetDenseElementsIndex():
572572
print(arr)
573573
# CHECK: True
574574
print(arr.dtype == np.int64)
575+
array = np.array([1, 2, 3], dtype=np.int64)
576+
attr = DenseIntElementsAttr.get(array, type=VectorType.get([3], idx_type))
577+
# CHECK: [1, 2, 3]
578+
print(list(DenseIntElementsAttr(attr)))
575579

576580

577581
# CHECK-LABEL: TEST: testGetDenseResourceElementsAttr

mlir/test/python/ir/attributes.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -366,6 +366,10 @@ def testDenseIntAttr():
366366
# CHECK: i1
367367
print(ShapedType(a.type).element_type)
368368

369+
shape = Attribute.parse("dense<[0, 1, 2, 3]> : vector<4xindex>")
370+
# CHECK: attr: dense<[0, 1, 2, 3]>
371+
print("attr:", shape)
372+
369373

370374
@run
371375
def testDenseArrayGetItem():

0 commit comments

Comments
 (0)