Skip to content

Commit eb28bbc

Browse files
authored
Merge pull request #126 from Xilinx/christopher.FXML-4134_python_bindings_index_type
[FIX] Python Bindings DenseIntElementsAttr as Vector<Index> type
2 parents 7ed5dd2 + 37e4ced commit eb28bbc

File tree

6 files changed

+25
-2
lines changed

6 files changed

+25
-2
lines changed

mlir/include/mlir-c/BuiltinAttributes.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -543,6 +543,8 @@ MLIR_CAPI_EXPORTED int64_t
543543
mlirDenseElementsAttrGetInt64Value(MlirAttribute attr, intptr_t pos);
544544
MLIR_CAPI_EXPORTED uint64_t
545545
mlirDenseElementsAttrGetUInt64Value(MlirAttribute attr, intptr_t pos);
546+
MLIR_CAPI_EXPORTED size_t mlirDenseElementsAttrGetIndexValue(MlirAttribute attr,
547+
intptr_t pos);
546548
MLIR_CAPI_EXPORTED float mlirDenseElementsAttrGetFloatValue(MlirAttribute attr,
547549
intptr_t pos);
548550
MLIR_CAPI_EXPORTED double

mlir/lib/Bindings/Python/IRAttributes.cpp

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -973,13 +973,19 @@ class PyDenseIntElementsAttribute
973973

974974
MlirType type = mlirAttributeGetType(*this);
975975
type = mlirShapedTypeGetElementType(type);
976-
assert(mlirTypeIsAInteger(type) &&
977-
"expected integer element type in dense int elements attribute");
976+
// Index type can also appear as a DenseIntElementsAttr and therefore can be
977+
// casted to integer.
978+
assert(mlirTypeIsAInteger(type) ||
979+
mlirTypeIsAIndex(type) && "expected integer/index element type in "
980+
"dense int elements attribute");
978981
// Dispatch element extraction to an appropriate C function based on the
979982
// elemental type of the attribute. py::int_ is implicitly constructible
980983
// from any C++ integral type and handles bitwidth correctly.
981984
// TODO: consider caching the type properties in the constructor to avoid
982985
// querying them on each element access.
986+
if (mlirTypeIsAIndex(type)) {
987+
return mlirDenseElementsAttrGetIndexValue(*this, pos);
988+
}
983989
unsigned width = mlirIntegerTypeGetWidth(type);
984990
bool isUnsigned = mlirIntegerTypeIsUnsigned(type);
985991
if (isUnsigned) {

mlir/lib/CAPI/IR/BuiltinAttributes.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -745,6 +745,9 @@ int64_t mlirDenseElementsAttrGetInt64Value(MlirAttribute attr, intptr_t pos) {
745745
uint64_t mlirDenseElementsAttrGetUInt64Value(MlirAttribute attr, intptr_t pos) {
746746
return llvm::cast<DenseElementsAttr>(unwrap(attr)).getValues<uint64_t>()[pos];
747747
}
748+
size_t mlirDenseElementsAttrGetIndexValue(MlirAttribute attr, intptr_t pos) {
749+
return llvm::cast<DenseElementsAttr>(unwrap(attr)).getValues<size_t>()[pos];
750+
}
748751
float mlirDenseElementsAttrGetFloatValue(MlirAttribute attr, intptr_t pos) {
749752
return llvm::cast<DenseElementsAttr>(unwrap(attr)).getValues<float>()[pos];
750753
}

mlir/test/python/dialects/builtin.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -246,3 +246,7 @@ def testDenseElementsAttr():
246246
# CHECK{LITERAL}: dense<[[0, 1], [2, 3]]> : tensor<2x2xi32>
247247
print(DenseElementsAttr.get(values, type=VectorType.get((2, 2), i32)))
248248
# CHECK{LITERAL}: dense<[[0, 1], [2, 3]]> : vector<2x2xi32>
249+
idx_values = np.arange(4, dtype=np.int64)
250+
idx_type = IndexType.get()
251+
print(DenseElementsAttr.get(idx_values, type=VectorType.get([4], idx_type)))
252+
# CHECK{LITERAL}: dense<[0, 1, 2, 3]> : vector<4xindex>

mlir/test/python/ir/array_attributes.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -418,6 +418,10 @@ def testGetDenseElementsIndex():
418418
print(arr)
419419
# CHECK: True
420420
print(arr.dtype == np.int64)
421+
array = np.array([1, 2, 3], dtype=np.int64)
422+
attr = DenseIntElementsAttr.get(array, type=VectorType.get([3], idx_type))
423+
# CHECK: [1, 2, 3]
424+
print(list(DenseIntElementsAttr(attr)))
421425

422426

423427
# CHECK-LABEL: TEST: testGetDenseResourceElementsAttr

mlir/test/python/ir/attributes.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -348,6 +348,10 @@ def testDenseIntAttr():
348348
# CHECK: i1
349349
print(ShapedType(a.type).element_type)
350350

351+
shape = Attribute.parse("dense<[0, 1, 2, 3]> : vector<4xindex>")
352+
# CHECK: attr: dense<[0, 1, 2, 3]>
353+
print("attr:", shape)
354+
351355

352356
@run
353357
def testDenseArrayGetItem():

0 commit comments

Comments
 (0)