
[mlir] support scalable vectors in python bindings #71050


Merged: 1 commit merged into llvm:main on Nov 6, 2023

Conversation

ftynse (Member) commented on Nov 2, 2023

The scalable dimension functionality was added to the vector type after the bindings for it were defined, without the bindings ever being updated. Fix that.
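
For context, the new API added by this patch can be exercised roughly as follows. This is a minimal sketch distilled from the test changes further down, assuming the standard upstream mlir.ir Python package:

    from mlir.ir import Context, Location, F32Type, VectorType

    with Context(), Location.unknown():
        f32 = F32Type.get()
        # Mark dimension 1 as scalable.
        vec = VectorType.get([2, 3], f32, scalable=[False, True])
        print(vec)                # vector<2x[3]xf32>
        print(vec.scalable)       # True
        print(vec.scalable_dims)  # [False, True]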

llvmbot (Member) commented on Nov 2, 2023

@llvm/pr-subscribers-mlir

Author: Oleksandr "Alex" Zinenko (ftynse)

Changes

The scalable dimension functionality was added to the vector type after the bindings for it were defined, without the bindings ever being updated. Fix that.


Full diff: https://github.com/llvm/llvm-project/pull/71050.diff

5 Files Affected:

  • (modified) mlir/include/mlir-c/BuiltinTypes.h (+26)
  • (modified) mlir/lib/Bindings/Python/IRTypes.cpp (+42-12)
  • (modified) mlir/lib/CAPI/IR/BuiltinTypes.cpp (+25)
  • (modified) mlir/test/CAPI/ir.c (+24-10)
  • (modified) mlir/test/python/ir/builtin_types.py (+22-1)
diff --git a/mlir/include/mlir-c/BuiltinTypes.h b/mlir/include/mlir-c/BuiltinTypes.h
index a6d8e10efbde923..1fd5691f41eec35 100644
--- a/mlir/include/mlir-c/BuiltinTypes.h
+++ b/mlir/include/mlir-c/BuiltinTypes.h
@@ -271,6 +271,32 @@ MLIR_CAPI_EXPORTED MlirType mlirVectorTypeGetChecked(MlirLocation loc,
                                                      const int64_t *shape,
                                                      MlirType elementType);
 
+/// Creates a scalable vector type with the shape identified by its rank and
+/// dimensions. A subset of dimensions may be marked as scalable via the
+/// corresponding flag list, which is expected to have as many entries as the
+/// rank of the vector. The vector is created in the same context as the element
+/// type.
+MLIR_CAPI_EXPORTED MlirType mlirVectorTypeGetScalable(intptr_t rank,
+                                                      const int64_t *shape,
+                                                      const bool *scalable,
+                                                      MlirType elementType);
+
+/// Same as "mlirVectorTypeGetScalable" but returns a nullptr wrapping MlirType
+/// on illegal arguments, emitting appropriate diagnostics.
+MLIR_CAPI_EXPORTED
+MlirType mlirVectorTypeGetScalableChecked(MlirLocation loc, intptr_t rank,
+                                          const int64_t *shape,
+                                          const bool *scalable,
+                                          MlirType elementType);
+
+/// Checks whether the given vector type is scalable, i.e., has at least one
+/// scalable dimension.
+MLIR_CAPI_EXPORTED bool mlirVectorTypeIsScalable(MlirType type);
+
+/// Checks whether the "dim"-th dimension of the given vector is scalable.
+MLIR_CAPI_EXPORTED bool mlirVectorTypeIsDimScalable(MlirType type,
+                                                    intptr_t dim);
+
 //===----------------------------------------------------------------------===//
 // Ranked / Unranked Tensor type.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp
index a7ccfbea542f5c7..e145e05ad9b4f19 100644
--- a/mlir/lib/Bindings/Python/IRTypes.cpp
+++ b/mlir/lib/Bindings/Python/IRTypes.cpp
@@ -12,6 +12,7 @@
 
 #include "mlir-c/BuiltinAttributes.h"
 #include "mlir-c/BuiltinTypes.h"
+#include "llvm/ADT/ScopeExit.h"
 #include <optional>
 
 namespace py = pybind11;
@@ -463,18 +464,47 @@ class PyVectorType : public PyConcreteType<PyVectorType, PyShapedType> {
 
   static void bindDerived(ClassTy &c) {
     c.def_static(
-        "get",
-        [](std::vector<int64_t> shape, PyType &elementType,
-           DefaultingPyLocation loc) {
-          PyMlirContext::ErrorCapture errors(loc->getContext());
-          MlirType t = mlirVectorTypeGetChecked(loc, shape.size(), shape.data(),
-                                                elementType);
-          if (mlirTypeIsNull(t))
-            throw MLIRError("Invalid type", errors.take());
-          return PyVectorType(elementType.getContext(), t);
-        },
-        py::arg("shape"), py::arg("elementType"), py::arg("loc") = py::none(),
-        "Create a vector type");
+         "get",
+         [](std::vector<int64_t> shape, PyType &elementType,
+            std::optional<py::list> scalable, DefaultingPyLocation loc) {
+           PyMlirContext::ErrorCapture errors(loc->getContext());
+           MlirType type;
+           if (scalable) {
+             if (scalable->size() != shape.size())
+               throw py::value_error("Expected len(scalable) == len(shape).");
+
+             // Vector-of-bool may be using bit packing, so we cannot access its
+             // data directly. Explicitly create an array-of-bool instead.
+             bool *scalableData =
+                 static_cast<bool *>(malloc(sizeof(bool) * scalable->size()));
+             auto deleter = llvm::make_scope_exit([&] { free(scalableData); });
+             auto range = llvm::map_range(
+                 *scalable, [](const py::handle &h) { return h.cast<bool>(); });
+             llvm::copy(range, scalableData);
+             type = mlirVectorTypeGetScalableChecked(
+                 loc, shape.size(), shape.data(), scalableData, elementType);
+           } else {
+             type = mlirVectorTypeGetChecked(loc, shape.size(), shape.data(),
+                                             elementType);
+           }
+           if (mlirTypeIsNull(type))
+             throw MLIRError("Invalid type", errors.take());
+           return PyVectorType(elementType.getContext(), type);
+         },
+         py::arg("shape"), py::arg("elementType"), py::kw_only(),
+         py::arg("scalable") = py::none(), py::arg("loc") = py::none(),
+         "Create a vector type")
+        .def_property_readonly(
+            "scalable",
+            [](MlirType self) { return mlirVectorTypeIsScalable(self); })
+        .def_property_readonly("scalable_dims", [](MlirType self) {
+          std::vector<bool> scalableDims;
+          size_t rank = static_cast<size_t>(mlirShapedTypeGetRank(self));
+          scalableDims.reserve(rank);
+          for (size_t i = 0; i < rank; ++i)
+            scalableDims.push_back(mlirVectorTypeIsDimScalable(self, i));
+          return scalableDims;
+        });
   }
 };
 
diff --git a/mlir/lib/CAPI/IR/BuiltinTypes.cpp b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
index 50266b4b5233235..6e645188dac8616 100644
--- a/mlir/lib/CAPI/IR/BuiltinTypes.cpp
+++ b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
@@ -281,6 +281,31 @@ MlirType mlirVectorTypeGetChecked(MlirLocation loc, intptr_t rank,
       unwrap(elementType)));
 }
 
+MlirType mlirVectorTypeGetScalable(intptr_t rank, const int64_t *shape,
+                                   const bool *scalable, MlirType elementType) {
+  return wrap(VectorType::get(
+      llvm::ArrayRef(shape, static_cast<size_t>(rank)), unwrap(elementType),
+      llvm::ArrayRef(scalable, static_cast<size_t>(rank))));
+}
+
+MlirType mlirVectorTypeGetScalableChecked(MlirLocation loc, intptr_t rank,
+                                          const int64_t *shape,
+                                          const bool *scalable,
+                                          MlirType elementType) {
+  return wrap(VectorType::getChecked(
+      unwrap(loc), llvm::ArrayRef(shape, static_cast<size_t>(rank)),
+      unwrap(elementType),
+      llvm::ArrayRef(scalable, static_cast<size_t>(rank))));
+}
+
+bool mlirVectorTypeIsScalable(MlirType type) {
+  return unwrap(type).cast<VectorType>().isScalable();
+}
+
+bool mlirVectorTypeIsDimScalable(MlirType type, intptr_t dim) {
+  return unwrap(type).cast<VectorType>().getScalableDims()[dim];
+}
+
 //===----------------------------------------------------------------------===//
 // Ranked / Unranked tensor type.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c
index c6425f80a8bce9c..3a2bdb9bfc93334 100644
--- a/mlir/test/CAPI/ir.c
+++ b/mlir/test/CAPI/ir.c
@@ -743,13 +743,27 @@ static int printBuiltinTypes(MlirContext ctx) {
   fprintf(stderr, "\n");
   // CHECK: vector<2x3xf32>
 
+  // Scalable vector type.
+  bool scalable[] = {false, true};
+  MlirType scalableVector = mlirVectorTypeGetScalable(
+      sizeof(shape) / sizeof(int64_t), shape, scalable, f32);
+  if (!mlirTypeIsAVector(scalableVector))
+    return 16;
+  if (!mlirVectorTypeIsScalable(scalableVector) ||
+      mlirVectorTypeIsDimScalable(scalableVector, 0) ||
+      !mlirVectorTypeIsDimScalable(scalableVector, 1))
+    return 17;
+  mlirTypeDump(scalableVector);
+  fprintf(stderr, "\n");
+  // CHECK: vector<2x[3]xf32>
+
   // Ranked tensor type.
   MlirType rankedTensor = mlirRankedTensorTypeGet(
       sizeof(shape) / sizeof(int64_t), shape, f32, mlirAttributeGetNull());
   if (!mlirTypeIsATensor(rankedTensor) ||
       !mlirTypeIsARankedTensor(rankedTensor) ||
       !mlirAttributeIsNull(mlirRankedTensorTypeGetEncoding(rankedTensor)))
-    return 16;
+    return 18;
   mlirTypeDump(rankedTensor);
   fprintf(stderr, "\n");
   // CHECK: tensor<2x3xf32>
@@ -759,7 +773,7 @@ static int printBuiltinTypes(MlirContext ctx) {
   if (!mlirTypeIsATensor(unrankedTensor) ||
       !mlirTypeIsAUnrankedTensor(unrankedTensor) ||
       mlirShapedTypeHasRank(unrankedTensor))
-    return 17;
+    return 19;
   mlirTypeDump(unrankedTensor);
   fprintf(stderr, "\n");
   // CHECK: tensor<*xf32>
@@ -770,7 +784,7 @@ static int printBuiltinTypes(MlirContext ctx) {
       f32, sizeof(shape) / sizeof(int64_t), shape, memSpace2);
   if (!mlirTypeIsAMemRef(memRef) ||
       !mlirAttributeEqual(mlirMemRefTypeGetMemorySpace(memRef), memSpace2))
-    return 18;
+    return 20;
   mlirTypeDump(memRef);
   fprintf(stderr, "\n");
   // CHECK: memref<2x3xf32, 2>
@@ -782,7 +796,7 @@ static int printBuiltinTypes(MlirContext ctx) {
       mlirTypeIsAMemRef(unrankedMemRef) ||
       !mlirAttributeEqual(mlirUnrankedMemrefGetMemorySpace(unrankedMemRef),
                           memSpace4))
-    return 19;
+    return 21;
   mlirTypeDump(unrankedMemRef);
   fprintf(stderr, "\n");
   // CHECK: memref<*xf32, 4>
@@ -793,7 +807,7 @@ static int printBuiltinTypes(MlirContext ctx) {
   if (!mlirTypeIsATuple(tuple) || mlirTupleTypeGetNumTypes(tuple) != 2 ||
       !mlirTypeEqual(mlirTupleTypeGetType(tuple, 0), unrankedMemRef) ||
       !mlirTypeEqual(mlirTupleTypeGetType(tuple, 1), f32))
-    return 20;
+    return 22;
   mlirTypeDump(tuple);
   fprintf(stderr, "\n");
   // CHECK: tuple<memref<*xf32, 4>, f32>
@@ -805,16 +819,16 @@ static int printBuiltinTypes(MlirContext ctx) {
                              mlirIntegerTypeGet(ctx, 64)};
   MlirType funcType = mlirFunctionTypeGet(ctx, 2, funcInputs, 3, funcResults);
   if (mlirFunctionTypeGetNumInputs(funcType) != 2)
-    return 21;
+    return 23;
   if (mlirFunctionTypeGetNumResults(funcType) != 3)
-    return 22;
+    return 24;
   if (!mlirTypeEqual(funcInputs[0], mlirFunctionTypeGetInput(funcType, 0)) ||
       !mlirTypeEqual(funcInputs[1], mlirFunctionTypeGetInput(funcType, 1)))
-    return 23;
+    return 25;
   if (!mlirTypeEqual(funcResults[0], mlirFunctionTypeGetResult(funcType, 0)) ||
       !mlirTypeEqual(funcResults[1], mlirFunctionTypeGetResult(funcType, 1)) ||
       !mlirTypeEqual(funcResults[2], mlirFunctionTypeGetResult(funcType, 2)))
-    return 24;
+    return 26;
   mlirTypeDump(funcType);
   fprintf(stderr, "\n");
   // CHECK: (index, i1) -> (i16, i32, i64)
@@ -829,7 +843,7 @@ static int printBuiltinTypes(MlirContext ctx) {
       !mlirStringRefEqual(mlirOpaqueTypeGetDialectNamespace(opaque),
                           namespace) ||
       !mlirStringRefEqual(mlirOpaqueTypeGetData(opaque), data))
-    return 25;
+    return 27;
   mlirTypeDump(opaque);
   fprintf(stderr, "\n");
   // CHECK: !dialect.type
diff --git a/mlir/test/python/ir/builtin_types.py b/mlir/test/python/ir/builtin_types.py
index 672418b5383ae45..e2344794c839a3a 100644
--- a/mlir/test/python/ir/builtin_types.py
+++ b/mlir/test/python/ir/builtin_types.py
@@ -300,7 +300,7 @@ def testVectorType():
 
         none = NoneType.get()
         try:
-            vector_invalid = VectorType.get(shape, none)
+            VectorType.get(shape, none)
         except MLIRError as e:
             # CHECK: Invalid type:
             # CHECK: error: unknown: vector elements must be int/index/float type but got 'none'
@@ -308,6 +308,27 @@ def testVectorType():
         else:
             print("Exception not produced")
 
+        scalable_1 = VectorType.get(shape, f32, scalable=[False, True])
+        scalable_2 = VectorType.get([2, 3, 4],
+                                    f32,
+                                    scalable=[True, False, True])
+        assert scalable_1.scalable
+        assert scalable_2.scalable
+        assert scalable_1.scalable_dims == [False, True]
+        assert scalable_2.scalable_dims == [True, False, True]
+        # CHECK: scalable 1: vector<2x[3]xf32>
+        print("scalable 1: ", scalable_1)
+        # CHECK: scalable 2: vector<[2]x3x[4]xf32>
+        print("scalable 2: ", scalable_2)
+
+        try:
+            VectorType.get(shape, f32, scalable=[False, True, True])
+        except ValueError as e:
+            # CHECK: Expected len(scalable) == len(shape).
+            print(e)
+        else:
+            print("Exception not produced")
+
 
 # CHECK-LABEL: TEST: testRankedTensorType
 @run

github-actions bot commented on Nov 2, 2023

✅ With the latest revision this PR passed the Python code formatter.

Review thread on mlir/test/python/ir/builtin_types.py:

except MLIRError as e:
# CHECK: Invalid type:
# CHECK: error: unknown: vector elements must be int/index/float type but got 'none'
print(e)
else:
print("Exception not produced")

scalable_1 = VectorType.get(shape, f32, scalable=[False, True])
scalable_2 = VectorType.get([2, 3, 4], f32, scalable=[True, False, True])
Contributor commented:

Base 10 is better than base 2 😃, i.e., it would be better if the Python API just required the user to list the dims that are scalable.

ftynse (Member, Author) replied:

This is consistent with how the C++ constructors are implemented. I prefer consistency over niceness here; one can always add a wrapper.

ftynse (Member, Author) replied:

Added a separate keyword for this, it's Python after all.
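
For illustration, the separate keyword mentioned here would be used along these lines. This is a sketch only: the keyword name scalable_dims and its index-based semantics are inferred from the later review thread (the snippet quoted around "Scalable dimension index out of bounds"), not from the diff shown above.

    from mlir.ir import Context, Location, F32Type, VectorType

    with Context(), Location.unknown():
        f32 = F32Type.get()
        # Hypothetical index-based form: list the scalable dimensions by
        # index instead of passing a per-dimension flag list.
        vec = VectorType.get([2, 3], f32, scalable_dims=[1])  # vector<2x[3]xf32>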

@makslevental self-requested a review on November 2, 2023, 15:57
makslevental (Contributor) reviewed:

LGTM


Review thread on mlir/lib/Bindings/Python/IRTypes.cpp:

// Vector-of-bool may be using bit packing, so we cannot access its
// data directly. Explicitly create an array-of-bool instead.
bool *scalableData =
Contributor commented:

Nit: Usually if I were resorting to low level alloc, I'd go all the way to alloca for something like this. Otherwise, I would use new[]/delete[].

Member commented:

Would SmallVector<bool> work here? I don't think it has a bool specialization, so its storage would stay a plain bool array.

ftynse (Member, Author) replied:

Good points, thanks!
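
For reference, a minimal sketch of the SmallVector-based variant suggested above. The merged revision is not shown on this page; the snippet assumes the surrounding lambda from IRTypes.cpp (shape, scalable, loc, elementType) plus #include "llvm/ADT/SmallVector.h":

    // SmallVector<bool> has no bit-packed specialization, so .data() yields
    // a contiguous bool* that can be handed straight to the C API, and the
    // storage is released automatically when the vector goes out of scope.
    llvm::SmallVector<bool> scalableDimFlags;
    scalableDimFlags.reserve(scalable->size());
    for (const py::handle &h : *scalable)
      scalableDimFlags.push_back(h.cast<bool>());
    MlirType type = mlirVectorTypeGetScalableChecked(
        loc, shape.size(), shape.data(), scalableDimFlags.data(), elementType);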

banach-space (Contributor) reviewed:

Thanks for the fix and apologies for missing this, LGTM!

Comment on lines 507 to 510
if (static_cast<size_t>(dim) >= scalableDimFlags.size() || dim < 0) {
throw py::value_error("Scalable dimension index out of bounds.");
scalableDimFlags[dim] = true;
}
Member commented:

Should that scalableDimFlags[dim] = true; be outside the if? It looks like it's unreachable due to the throw before it.

ftynse (Member, Author) replied:

Good catch, thank you!
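
For clarity, the corrected form discussed above would look like this (a sketch; the assignment moves after the bounds check so it is actually executed):

    if (static_cast<size_t>(dim) >= scalableDimFlags.size() || dim < 0)
      throw py::value_error("Scalable dimension index out of bounds.");
    scalableDimFlags[dim] = true;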

The scalable dimension functionality was added to the vector type after
the bindings for it were defined, without the bindings being ever
updated. Fix that.
ftynse merged commit 96dadc9 into llvm:main on Nov 6, 2023
ftynse deleted the scalable branch on November 6, 2023, 12:15
michalt (Contributor) commented on Nov 6, 2023

Awesome, thanks for getting this merged! 😄
