[mlir] support scalable vectors in python bindings #71050
Conversation
@llvm/pr-subscribers-mlir

Author: Oleksandr "Alex" Zinenko (ftynse)

Changes

The scalable dimension functionality was added to the vector type after the bindings for it were defined, without the bindings ever being updated. Fix that.

Full diff: https://github.com/llvm/llvm-project/pull/71050.diff

5 Files Affected:
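For a quick sense of the resulting Python API, here is a minimal usage sketch mirroring the tests added below (assuming an in-tree build of the MLIR Python bindings; the scalable argument is keyword-only):

from mlir.ir import Context, F32Type, Location, VectorType

with Context(), Location.unknown():
    f32 = F32Type.get()
    # Mark the trailing dimension as scalable: vector<2x[3]xf32>.
    vec = VectorType.get([2, 3], f32, scalable=[False, True])
    assert vec.scalable
    assert vec.scalable_dims == [False, True]
    print(vec)  # vector<2x[3]xf32>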
diff --git a/mlir/include/mlir-c/BuiltinTypes.h b/mlir/include/mlir-c/BuiltinTypes.h
index a6d8e10efbde923..1fd5691f41eec35 100644
--- a/mlir/include/mlir-c/BuiltinTypes.h
+++ b/mlir/include/mlir-c/BuiltinTypes.h
@@ -271,6 +271,32 @@ MLIR_CAPI_EXPORTED MlirType mlirVectorTypeGetChecked(MlirLocation loc,
const int64_t *shape,
MlirType elementType);
+/// Creates a scalable vector type with the shape identified by its rank and
+/// dimensions. A subset of dimensions may be marked as scalable via the
+/// corresponding flag list, which is expected to have as many entries as the
+/// rank of the vector. The vector is created in the same context as the element
+/// type.
+MLIR_CAPI_EXPORTED MlirType mlirVectorTypeGetScalable(intptr_t rank,
+ const int64_t *shape,
+ const bool *scalable,
+ MlirType elementType);
+
+/// Same as "mlirVectorTypeGetScalable" but returns a nullptr wrapping MlirType
+/// on illegal arguments, emitting appropriate diagnostics.
+MLIR_CAPI_EXPORTED
+MlirType mlirVectorTypeGetScalableChecked(MlirLocation loc, intptr_t rank,
+ const int64_t *shape,
+ const bool *scalable,
+ MlirType elementType);
+
+/// Checks whether the given vector type is scalable, i.e., has at least one
+/// scalable dimension.
+MLIR_CAPI_EXPORTED bool mlirVectorTypeIsScalable(MlirType type);
+
+/// Checks whether the "dim"-th dimension of the given vector is scalable.
+MLIR_CAPI_EXPORTED bool mlirVectorTypeIsDimScalable(MlirType type,
+ intptr_t dim);
+
//===----------------------------------------------------------------------===//
// Ranked / Unranked Tensor type.
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp
index a7ccfbea542f5c7..e145e05ad9b4f19 100644
--- a/mlir/lib/Bindings/Python/IRTypes.cpp
+++ b/mlir/lib/Bindings/Python/IRTypes.cpp
@@ -12,6 +12,7 @@
#include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/BuiltinTypes.h"
+#include "llvm/ADT/ScopeExit.h"
#include <optional>
namespace py = pybind11;
@@ -463,18 +464,47 @@ class PyVectorType : public PyConcreteType<PyVectorType, PyShapedType> {
static void bindDerived(ClassTy &c) {
c.def_static(
- "get",
- [](std::vector<int64_t> shape, PyType &elementType,
- DefaultingPyLocation loc) {
- PyMlirContext::ErrorCapture errors(loc->getContext());
- MlirType t = mlirVectorTypeGetChecked(loc, shape.size(), shape.data(),
- elementType);
- if (mlirTypeIsNull(t))
- throw MLIRError("Invalid type", errors.take());
- return PyVectorType(elementType.getContext(), t);
- },
- py::arg("shape"), py::arg("elementType"), py::arg("loc") = py::none(),
- "Create a vector type");
+ "get",
+ [](std::vector<int64_t> shape, PyType &elementType,
+ std::optional<py::list> scalable, DefaultingPyLocation loc) {
+ PyMlirContext::ErrorCapture errors(loc->getContext());
+ MlirType type;
+ if (scalable) {
+ if (scalable->size() != shape.size())
+ throw py::value_error("Expected len(scalable) == len(shape).");
+
+ // Vector-of-bool may be using bit packing, so we cannot access its
+ // data directly. Explicitly create an array-of-bool instead.
+ bool *scalableData =
+ static_cast<bool *>(malloc(sizeof(bool) * scalable->size()));
+ auto deleter = llvm::make_scope_exit([&] { free(scalableData); });
+ auto range = llvm::map_range(
+ *scalable, [](const py::handle &h) { return h.cast<bool>(); });
+ llvm::copy(range, scalableData);
+ type = mlirVectorTypeGetScalableChecked(
+ loc, shape.size(), shape.data(), scalableData, elementType);
+ } else {
+ type = mlirVectorTypeGetChecked(loc, shape.size(), shape.data(),
+ elementType);
+ }
+ if (mlirTypeIsNull(type))
+ throw MLIRError("Invalid type", errors.take());
+ return PyVectorType(elementType.getContext(), type);
+ },
+ py::arg("shape"), py::arg("elementType"), py::kw_only(),
+ py::arg("scalable") = py::none(), py::arg("loc") = py::none(),
+ "Create a vector type")
+ .def_property_readonly(
+ "scalable",
+ [](MlirType self) { return mlirVectorTypeIsScalable(self); })
+ .def_property_readonly("scalable_dims", [](MlirType self) {
+ std::vector<bool> scalableDims;
+ size_t rank = static_cast<size_t>(mlirShapedTypeGetRank(self));
+ scalableDims.reserve(rank);
+ for (size_t i = 0; i < rank; ++i)
+ scalableDims.push_back(mlirVectorTypeIsDimScalable(self, i));
+ return scalableDims;
+ });
}
};
diff --git a/mlir/lib/CAPI/IR/BuiltinTypes.cpp b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
index 50266b4b5233235..6e645188dac8616 100644
--- a/mlir/lib/CAPI/IR/BuiltinTypes.cpp
+++ b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
@@ -281,6 +281,31 @@ MlirType mlirVectorTypeGetChecked(MlirLocation loc, intptr_t rank,
unwrap(elementType)));
}
+MlirType mlirVectorTypeGetScalable(intptr_t rank, const int64_t *shape,
+ const bool *scalable, MlirType elementType) {
+ return wrap(VectorType::get(
+ llvm::ArrayRef(shape, static_cast<size_t>(rank)), unwrap(elementType),
+ llvm::ArrayRef(scalable, static_cast<size_t>(rank))));
+}
+
+MlirType mlirVectorTypeGetScalableChecked(MlirLocation loc, intptr_t rank,
+ const int64_t *shape,
+ const bool *scalable,
+ MlirType elementType) {
+ return wrap(VectorType::getChecked(
+ unwrap(loc), llvm::ArrayRef(shape, static_cast<size_t>(rank)),
+ unwrap(elementType),
+ llvm::ArrayRef(scalable, static_cast<size_t>(rank))));
+}
+
+bool mlirVectorTypeIsScalable(MlirType type) {
+ return unwrap(type).cast<VectorType>().isScalable();
+}
+
+bool mlirVectorTypeIsDimScalable(MlirType type, intptr_t dim) {
+ return unwrap(type).cast<VectorType>().getScalableDims()[dim];
+}
+
//===----------------------------------------------------------------------===//
// Ranked / Unranked tensor type.
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c
index c6425f80a8bce9c..3a2bdb9bfc93334 100644
--- a/mlir/test/CAPI/ir.c
+++ b/mlir/test/CAPI/ir.c
@@ -743,13 +743,27 @@ static int printBuiltinTypes(MlirContext ctx) {
fprintf(stderr, "\n");
// CHECK: vector<2x3xf32>
+ // Scalable vector type.
+ bool scalable[] = {false, true};
+ MlirType scalableVector = mlirVectorTypeGetScalable(
+ sizeof(shape) / sizeof(int64_t), shape, scalable, f32);
+ if (!mlirTypeIsAVector(scalableVector))
+ return 16;
+ if (!mlirVectorTypeIsScalable(scalableVector) ||
+ mlirVectorTypeIsDimScalable(scalableVector, 0) ||
+ !mlirVectorTypeIsDimScalable(scalableVector, 1))
+ return 17;
+ mlirTypeDump(scalableVector);
+ fprintf(stderr, "\n");
+ // CHECK: vector<2x[3]xf32>
+
// Ranked tensor type.
MlirType rankedTensor = mlirRankedTensorTypeGet(
sizeof(shape) / sizeof(int64_t), shape, f32, mlirAttributeGetNull());
if (!mlirTypeIsATensor(rankedTensor) ||
!mlirTypeIsARankedTensor(rankedTensor) ||
!mlirAttributeIsNull(mlirRankedTensorTypeGetEncoding(rankedTensor)))
- return 16;
+ return 18;
mlirTypeDump(rankedTensor);
fprintf(stderr, "\n");
// CHECK: tensor<2x3xf32>
@@ -759,7 +773,7 @@ static int printBuiltinTypes(MlirContext ctx) {
if (!mlirTypeIsATensor(unrankedTensor) ||
!mlirTypeIsAUnrankedTensor(unrankedTensor) ||
mlirShapedTypeHasRank(unrankedTensor))
- return 17;
+ return 19;
mlirTypeDump(unrankedTensor);
fprintf(stderr, "\n");
// CHECK: tensor<*xf32>
@@ -770,7 +784,7 @@ static int printBuiltinTypes(MlirContext ctx) {
f32, sizeof(shape) / sizeof(int64_t), shape, memSpace2);
if (!mlirTypeIsAMemRef(memRef) ||
!mlirAttributeEqual(mlirMemRefTypeGetMemorySpace(memRef), memSpace2))
- return 18;
+ return 20;
mlirTypeDump(memRef);
fprintf(stderr, "\n");
// CHECK: memref<2x3xf32, 2>
@@ -782,7 +796,7 @@ static int printBuiltinTypes(MlirContext ctx) {
mlirTypeIsAMemRef(unrankedMemRef) ||
!mlirAttributeEqual(mlirUnrankedMemrefGetMemorySpace(unrankedMemRef),
memSpace4))
- return 19;
+ return 21;
mlirTypeDump(unrankedMemRef);
fprintf(stderr, "\n");
// CHECK: memref<*xf32, 4>
@@ -793,7 +807,7 @@ static int printBuiltinTypes(MlirContext ctx) {
if (!mlirTypeIsATuple(tuple) || mlirTupleTypeGetNumTypes(tuple) != 2 ||
!mlirTypeEqual(mlirTupleTypeGetType(tuple, 0), unrankedMemRef) ||
!mlirTypeEqual(mlirTupleTypeGetType(tuple, 1), f32))
- return 20;
+ return 22;
mlirTypeDump(tuple);
fprintf(stderr, "\n");
// CHECK: tuple<memref<*xf32, 4>, f32>
@@ -805,16 +819,16 @@ static int printBuiltinTypes(MlirContext ctx) {
mlirIntegerTypeGet(ctx, 64)};
MlirType funcType = mlirFunctionTypeGet(ctx, 2, funcInputs, 3, funcResults);
if (mlirFunctionTypeGetNumInputs(funcType) != 2)
- return 21;
+ return 23;
if (mlirFunctionTypeGetNumResults(funcType) != 3)
- return 22;
+ return 24;
if (!mlirTypeEqual(funcInputs[0], mlirFunctionTypeGetInput(funcType, 0)) ||
!mlirTypeEqual(funcInputs[1], mlirFunctionTypeGetInput(funcType, 1)))
- return 23;
+ return 25;
if (!mlirTypeEqual(funcResults[0], mlirFunctionTypeGetResult(funcType, 0)) ||
!mlirTypeEqual(funcResults[1], mlirFunctionTypeGetResult(funcType, 1)) ||
!mlirTypeEqual(funcResults[2], mlirFunctionTypeGetResult(funcType, 2)))
- return 24;
+ return 26;
mlirTypeDump(funcType);
fprintf(stderr, "\n");
// CHECK: (index, i1) -> (i16, i32, i64)
@@ -829,7 +843,7 @@ static int printBuiltinTypes(MlirContext ctx) {
!mlirStringRefEqual(mlirOpaqueTypeGetDialectNamespace(opaque),
namespace) ||
!mlirStringRefEqual(mlirOpaqueTypeGetData(opaque), data))
- return 25;
+ return 27;
mlirTypeDump(opaque);
fprintf(stderr, "\n");
// CHECK: !dialect.type
diff --git a/mlir/test/python/ir/builtin_types.py b/mlir/test/python/ir/builtin_types.py
index 672418b5383ae45..e2344794c839a3a 100644
--- a/mlir/test/python/ir/builtin_types.py
+++ b/mlir/test/python/ir/builtin_types.py
@@ -300,7 +300,7 @@ def testVectorType():
none = NoneType.get()
try:
- vector_invalid = VectorType.get(shape, none)
+ VectorType.get(shape, none)
except MLIRError as e:
# CHECK: Invalid type:
# CHECK: error: unknown: vector elements must be int/index/float type but got 'none'
@@ -308,6 +308,27 @@ def testVectorType():
else:
print("Exception not produced")
+ scalable_1 = VectorType.get(shape, f32, scalable=[False, True])
+ scalable_2 = VectorType.get([2, 3, 4],
+ f32,
+ scalable=[True, False, True])
+ assert scalable_1.scalable
+ assert scalable_2.scalable
+ assert scalable_1.scalable_dims == [False, True]
+ assert scalable_2.scalable_dims == [True, False, True]
+ # CHECK: scalable 1: vector<2x[3]xf32>
+ print("scalable 1: ", scalable_1)
+ # CHECK: scalable 2: vector<[2]x3x[4]xf32>
+ print("scalable 2: ", scalable_2)
+
+ try:
+ VectorType.get(shape, f32, scalable=[False, True, True])
+ except ValueError as e:
+ # CHECK: Expected len(scalable) == len(shape).
+ print(e)
+ else:
+ print("Exception not produced")
+
# CHECK-LABEL: TEST: testRankedTensorType
@run
✅ With the latest revision this PR passed the Python code formatter.
    except MLIRError as e:
        # CHECK: Invalid type:
        # CHECK: error: unknown: vector elements must be int/index/float type but got 'none'
        print(e)
    else:
        print("Exception not produced")

    scalable_1 = VectorType.get(shape, f32, scalable=[False, True])
    scalable_2 = VectorType.get([2, 3, 4], f32, scalable=[True, False, True])
Base 10 is better than base 2 😃 i.e., it would be better if the python api just required the user to list the dims that are scalable.
This is consistent with how C++ constructors are implemented. I prefer consistency over niceness here, one can always add a wrapper.
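For reference, the C++ builder that the new C API functions in this diff forward to takes the same parallel list of per-dimension flags (f32Type here stands in for any valid element type):

// Build vector<2x[3]xf32> in C++; scalableDims parallels the shape.
VectorType vt = VectorType::get(/*shape=*/{2, 3}, f32Type,
                                /*scalableDims=*/{false, true});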
Okay that reminds me to upstream https://github.com/makslevental/mlir-python-utils/blob/main/mlir/utils/types.py.
Added a separate keyword for this, it's Python after all.
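A hypothetical sketch of the index-based spelling discussed above (the exact keyword added in the follow-up revision is not part of this diff; the helper name below is illustrative):

def get_scalable_vector_type(shape, element_type, scalable_dim_indices=()):
    # Translate a list of scalable dimension indices into the per-dimension
    # flag list expected by VectorType.get.
    flags = [False] * len(shape)
    for i in scalable_dim_indices:
        flags[i] = True
    return VectorType.get(shape, element_type, scalable=flags)

# vector<[2]x3x[4]xf32>
# vt = get_scalable_vector_type([2, 3, 4], f32, scalable_dim_indices=[0, 2])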
LGTM
mlir/lib/Bindings/Python/IRTypes.cpp (Outdated)
// Vector-of-bool may be using bit packing, so we cannot access its
// data directly. Explicitly create an array-of-bool instead.
bool *scalableData =
Nit: Usually if I were resorting to low level alloc, I'd go all the way to alloca for something like this. Otherwise, I would use new[]/delete[].
Would SmallVector<bool> work here? I don't think it has a bool specialization.
Good points, thanks!
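A sketch of the suggested alternative, assuming it drops into the same pybind11 lambda (llvm/ADT/SmallVector.h and llvm/ADT/STLExtras.h provide to_vector and map_range): SmallVector has no bit-packed specialization for bool, so its data() pointer can be handed straight to the C API.

// Unlike std::vector<bool>, SmallVector<bool> stores plain bools, so
// .data() yields a bool* suitable for the C API.
llvm::SmallVector<bool> scalableDimFlags = llvm::to_vector(llvm::map_range(
    *scalable, [](const py::handle &h) { return h.cast<bool>(); }));
type = mlirVectorTypeGetScalableChecked(loc, shape.size(), shape.data(),
                                        scalableDimFlags.data(), elementType);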
Thanks for the fix and apologies for missing this, LGTM!
mlir/lib/Bindings/Python/IRTypes.cpp (Outdated)
if (static_cast<size_t>(dim) >= scalableDimFlags.size() || dim < 0) {
  throw py::value_error("Scalable dimension index out of bounds.");
  scalableDimFlags[dim] = true;
}
Should that scalableDimFlags[dim] = true; be outside the if? It looks like it's unreachable due to the throw before it.
Good catch, thank you!
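For clarity, the corrected control flow (a sketch of the fix): the assignment has to follow the throwing bounds check rather than sit inside it, otherwise it is never executed.

// Bounds-check the requested dimension first, then record it as scalable.
if (dim < 0 || static_cast<size_t>(dim) >= scalableDimFlags.size())
  throw py::value_error("Scalable dimension index out of bounds.");
scalableDimFlags[dim] = true;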
Awesome, thanks for getting this merged! 😄