Skip to content

Commit f0d72c5

Browse files
committed
[mlir] support scalable vectors in python bindings
The scalable dimension functionality was added to the vector type after the bindings for it were defined, without the bindings ever being updated. Fix that.
1 parent 32521bb commit f0d72c5

File tree

5 files changed

+172
-25
lines changed

5 files changed

+172
-25
lines changed

mlir/include/mlir-c/BuiltinTypes.h

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -271,6 +271,32 @@ MLIR_CAPI_EXPORTED MlirType mlirVectorTypeGetChecked(MlirLocation loc,
271271
const int64_t *shape,
272272
MlirType elementType);
273273

274+
/// Creates a scalable vector type with the shape identified by its rank and
/// dimensions. A subset of dimensions may be marked as scalable via the
/// corresponding flag list, which is expected to have as many entries as the
/// rank of the vector. The vector is created in the same context as the element
/// type.
MLIR_CAPI_EXPORTED MlirType mlirVectorTypeGetScalable(intptr_t rank,
                                                      const int64_t *shape,
                                                      const bool *scalable,
                                                      MlirType elementType);

/// Same as "mlirVectorTypeGetScalable" but returns a nullptr wrapping MlirType
/// on illegal arguments, emitting appropriate diagnostics.
MLIR_CAPI_EXPORTED
MlirType mlirVectorTypeGetScalableChecked(MlirLocation loc, intptr_t rank,
                                          const int64_t *shape,
                                          const bool *scalable,
                                          MlirType elementType);

/// Checks whether the given vector type is scalable, i.e., has at least one
/// scalable dimension.
MLIR_CAPI_EXPORTED bool mlirVectorTypeIsScalable(MlirType type);

/// Checks whether the "dim"-th dimension of the given vector is scalable.
MLIR_CAPI_EXPORTED bool mlirVectorTypeIsDimScalable(MlirType type,
                                                    intptr_t dim);
299+
274300
//===----------------------------------------------------------------------===//
275301
// Ranked / Unranked Tensor type.
276302
//===----------------------------------------------------------------------===//

mlir/lib/Bindings/Python/IRTypes.cpp

Lines changed: 56 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -462,19 +462,62 @@ class PyVectorType : public PyConcreteType<PyVectorType, PyShapedType> {
462462
using PyConcreteType::PyConcreteType;
463463

464464
static void bindDerived(ClassTy &c) {
  // Bind the factory as a static method and expose read-only accessors for
  // the scalability information of the vector type.
  c.def_static("get", &PyVectorType::get, py::arg("shape"),
               py::arg("elementType"), py::kw_only(),
               py::arg("scalable") = py::none(),
               py::arg("scalable_dims") = py::none(),
               py::arg("loc") = py::none(), "Create a vector type")
      .def_property_readonly(
          "scalable",
          [](MlirType self) { return mlirVectorTypeIsScalable(self); })
      .def_property_readonly("scalable_dims", [](MlirType self) {
        // Collect the per-dimension scalability flags into a Python-visible
        // list of booleans, one entry per vector dimension.
        std::vector<bool> scalableDims;
        size_t rank = static_cast<size_t>(mlirShapedTypeGetRank(self));
        scalableDims.reserve(rank);
        for (size_t i = 0; i < rank; ++i)
          scalableDims.push_back(mlirVectorTypeIsDimScalable(self, i));
        return scalableDims;
      });
}
482+
483+
private:
  /// Creates a vector type from Python arguments. Exactly one of `scalable`
  /// (a list of per-dimension booleans, same length as `shape`) or
  /// `scalable_dims` (a list of indices of the scalable dimensions) may be
  /// provided; when neither is given, a fixed (non-scalable) vector type is
  /// created. Throws py::value_error on inconsistent kwargs and MLIRError
  /// when the resulting type is invalid (with captured diagnostics).
  static PyVectorType get(std::vector<int64_t> shape, PyType &elementType,
                          std::optional<py::list> scalable,
                          std::optional<std::vector<int64_t>> scalableDims,
                          DefaultingPyLocation loc) {
    if (scalable && scalableDims) {
      throw py::value_error("'scalable' and 'scalable_dims' kwargs "
                            "are mutually exclusive.");
    }

    PyMlirContext::ErrorCapture errors(loc->getContext());
    MlirType type;
    if (scalable) {
      // The boolean flag list must cover every dimension of the shape.
      if (scalable->size() != shape.size())
        throw py::value_error("Expected len(scalable) == len(shape).");

      SmallVector<bool> scalableDimFlags = llvm::to_vector(llvm::map_range(
          *scalable, [](const py::handle &h) { return h.cast<bool>(); }));
      type = mlirVectorTypeGetScalableChecked(loc, shape.size(), shape.data(),
                                              scalableDimFlags.data(),
                                              elementType);
    } else if (scalableDims) {
      // Translate a list of scalable dimension indices into per-dimension
      // flags.
      SmallVector<bool> scalableDimFlags(shape.size(), false);
      for (int64_t dim : *scalableDims) {
        // Reject negative indices before casting to an unsigned type so the
        // validation does not rely on unsigned wrap-around.
        if (dim < 0 || static_cast<size_t>(dim) >= scalableDimFlags.size())
          throw py::value_error("Scalable dimension index out of bounds.");
        scalableDimFlags[dim] = true;
      }
      type = mlirVectorTypeGetScalableChecked(loc, shape.size(), shape.data(),
                                              scalableDimFlags.data(),
                                              elementType);
    } else {
      type = mlirVectorTypeGetChecked(loc, shape.size(), shape.data(),
                                      elementType);
    }
    if (mlirTypeIsNull(type))
      throw MLIRError("Invalid type", errors.take());
    return PyVectorType(elementType.getContext(), type);
  }
479522
};
480523

mlir/lib/CAPI/IR/BuiltinTypes.cpp

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -281,6 +281,31 @@ MlirType mlirVectorTypeGetChecked(MlirLocation loc, intptr_t rank,
281281
unwrap(elementType)));
282282
}
283283

284+
MlirType mlirVectorTypeGetScalable(intptr_t rank, const int64_t *shape,
                                   const bool *scalable, MlirType elementType) {
  // Wrap the raw C arrays as ArrayRefs of length `rank` and forward to the
  // C++ builder.
  auto dims = llvm::ArrayRef(shape, static_cast<size_t>(rank));
  auto scalableDims = llvm::ArrayRef(scalable, static_cast<size_t>(rank));
  return wrap(VectorType::get(dims, unwrap(elementType), scalableDims));
}
290+
291+
MlirType mlirVectorTypeGetScalableChecked(MlirLocation loc, intptr_t rank,
                                          const int64_t *shape,
                                          const bool *scalable,
                                          MlirType elementType) {
  // Checked variant: emits diagnostics at `loc` and yields a null type on
  // invalid input instead of asserting.
  auto dims = llvm::ArrayRef(shape, static_cast<size_t>(rank));
  auto scalableDims = llvm::ArrayRef(scalable, static_cast<size_t>(rank));
  return wrap(VectorType::getChecked(unwrap(loc), dims, unwrap(elementType),
                                     scalableDims));
}
300+
301+
bool mlirVectorTypeIsScalable(MlirType type) {
  // A vector type is scalable if at least one of its dimensions is scalable.
  VectorType vectorType = unwrap(type).cast<VectorType>();
  return vectorType.isScalable();
}
304+
305+
bool mlirVectorTypeIsDimScalable(MlirType type, intptr_t dim) {
  // NOTE(review): `dim` is not bounds-checked here; callers are expected to
  // pass a valid dimension index, consistent with other shaped-type C API
  // accessors — confirm this matches the header's documented contract.
  return unwrap(type).cast<VectorType>().getScalableDims()[dim];
}
308+
284309
//===----------------------------------------------------------------------===//
285310
// Ranked / Unranked tensor type.
286311
//===----------------------------------------------------------------------===//

mlir/test/CAPI/ir.c

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -743,13 +743,27 @@ static int printBuiltinTypes(MlirContext ctx) {
743743
fprintf(stderr, "\n");
744744
// CHECK: vector<2x3xf32>
745745

746+
// Scalable vector type.
747+
bool scalable[] = {false, true};
748+
MlirType scalableVector = mlirVectorTypeGetScalable(
749+
sizeof(shape) / sizeof(int64_t), shape, scalable, f32);
750+
if (!mlirTypeIsAVector(scalableVector))
751+
return 16;
752+
if (!mlirVectorTypeIsScalable(scalableVector) ||
753+
mlirVectorTypeIsDimScalable(scalableVector, 0) ||
754+
!mlirVectorTypeIsDimScalable(scalableVector, 1))
755+
return 17;
756+
mlirTypeDump(scalableVector);
757+
fprintf(stderr, "\n");
758+
// CHECK: vector<2x[3]xf32>
759+
746760
// Ranked tensor type.
747761
MlirType rankedTensor = mlirRankedTensorTypeGet(
748762
sizeof(shape) / sizeof(int64_t), shape, f32, mlirAttributeGetNull());
749763
if (!mlirTypeIsATensor(rankedTensor) ||
750764
!mlirTypeIsARankedTensor(rankedTensor) ||
751765
!mlirAttributeIsNull(mlirRankedTensorTypeGetEncoding(rankedTensor)))
752-
return 16;
766+
return 18;
753767
mlirTypeDump(rankedTensor);
754768
fprintf(stderr, "\n");
755769
// CHECK: tensor<2x3xf32>
@@ -759,7 +773,7 @@ static int printBuiltinTypes(MlirContext ctx) {
759773
if (!mlirTypeIsATensor(unrankedTensor) ||
760774
!mlirTypeIsAUnrankedTensor(unrankedTensor) ||
761775
mlirShapedTypeHasRank(unrankedTensor))
762-
return 17;
776+
return 19;
763777
mlirTypeDump(unrankedTensor);
764778
fprintf(stderr, "\n");
765779
// CHECK: tensor<*xf32>
@@ -770,7 +784,7 @@ static int printBuiltinTypes(MlirContext ctx) {
770784
f32, sizeof(shape) / sizeof(int64_t), shape, memSpace2);
771785
if (!mlirTypeIsAMemRef(memRef) ||
772786
!mlirAttributeEqual(mlirMemRefTypeGetMemorySpace(memRef), memSpace2))
773-
return 18;
787+
return 20;
774788
mlirTypeDump(memRef);
775789
fprintf(stderr, "\n");
776790
// CHECK: memref<2x3xf32, 2>
@@ -782,7 +796,7 @@ static int printBuiltinTypes(MlirContext ctx) {
782796
mlirTypeIsAMemRef(unrankedMemRef) ||
783797
!mlirAttributeEqual(mlirUnrankedMemrefGetMemorySpace(unrankedMemRef),
784798
memSpace4))
785-
return 19;
799+
return 21;
786800
mlirTypeDump(unrankedMemRef);
787801
fprintf(stderr, "\n");
788802
// CHECK: memref<*xf32, 4>
@@ -793,7 +807,7 @@ static int printBuiltinTypes(MlirContext ctx) {
793807
if (!mlirTypeIsATuple(tuple) || mlirTupleTypeGetNumTypes(tuple) != 2 ||
794808
!mlirTypeEqual(mlirTupleTypeGetType(tuple, 0), unrankedMemRef) ||
795809
!mlirTypeEqual(mlirTupleTypeGetType(tuple, 1), f32))
796-
return 20;
810+
return 22;
797811
mlirTypeDump(tuple);
798812
fprintf(stderr, "\n");
799813
// CHECK: tuple<memref<*xf32, 4>, f32>
@@ -805,16 +819,16 @@ static int printBuiltinTypes(MlirContext ctx) {
805819
mlirIntegerTypeGet(ctx, 64)};
806820
MlirType funcType = mlirFunctionTypeGet(ctx, 2, funcInputs, 3, funcResults);
807821
if (mlirFunctionTypeGetNumInputs(funcType) != 2)
808-
return 21;
822+
return 23;
809823
if (mlirFunctionTypeGetNumResults(funcType) != 3)
810-
return 22;
824+
return 24;
811825
if (!mlirTypeEqual(funcInputs[0], mlirFunctionTypeGetInput(funcType, 0)) ||
812826
!mlirTypeEqual(funcInputs[1], mlirFunctionTypeGetInput(funcType, 1)))
813-
return 23;
827+
return 25;
814828
if (!mlirTypeEqual(funcResults[0], mlirFunctionTypeGetResult(funcType, 0)) ||
815829
!mlirTypeEqual(funcResults[1], mlirFunctionTypeGetResult(funcType, 1)) ||
816830
!mlirTypeEqual(funcResults[2], mlirFunctionTypeGetResult(funcType, 2)))
817-
return 24;
831+
return 26;
818832
mlirTypeDump(funcType);
819833
fprintf(stderr, "\n");
820834
// CHECK: (index, i1) -> (i16, i32, i64)
@@ -829,7 +843,7 @@ static int printBuiltinTypes(MlirContext ctx) {
829843
!mlirStringRefEqual(mlirOpaqueTypeGetDialectNamespace(opaque),
830844
namespace) ||
831845
!mlirStringRefEqual(mlirOpaqueTypeGetData(opaque), data))
832-
return 25;
846+
return 27;
833847
mlirTypeDump(opaque);
834848
fprintf(stderr, "\n");
835849
// CHECK: !dialect.type

mlir/test/python/ir/builtin_types.py

Lines changed: 41 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -300,14 +300,54 @@ def testVectorType():
300300

301301
none = NoneType.get()
302302
try:
303-
vector_invalid = VectorType.get(shape, none)
303+
VectorType.get(shape, none)
304304
except MLIRError as e:
305305
# CHECK: Invalid type:
306306
# CHECK: error: unknown: vector elements must be int/index/float type but got 'none'
307307
print(e)
308308
else:
309309
print("Exception not produced")
310310

311+
scalable_1 = VectorType.get(shape, f32, scalable=[False, True])
312+
scalable_2 = VectorType.get([2, 3, 4], f32, scalable=[True, False, True])
313+
assert scalable_1.scalable
314+
assert scalable_2.scalable
315+
assert scalable_1.scalable_dims == [False, True]
316+
assert scalable_2.scalable_dims == [True, False, True]
317+
# CHECK: scalable 1: vector<2x[3]xf32>
318+
print("scalable 1: ", scalable_1)
319+
# CHECK: scalable 2: vector<[2]x3x[4]xf32>
320+
print("scalable 2: ", scalable_2)
321+
322+
scalable_3 = VectorType.get(shape, f32, scalable_dims=[1])
323+
scalable_4 = VectorType.get([2, 3, 4], f32, scalable_dims=[0, 2])
324+
assert scalable_3 == scalable_1
325+
assert scalable_4 == scalable_2
326+
327+
try:
328+
VectorType.get(shape, f32, scalable=[False, True, True])
329+
except ValueError as e:
330+
# CHECK: Expected len(scalable) == len(shape).
331+
print(e)
332+
else:
333+
print("Exception not produced")
334+
335+
try:
336+
VectorType.get(shape, f32, scalable=[False, True], scalable_dims=[1])
337+
except ValueError as e:
338+
# CHECK: kwargs are mutually exclusive.
339+
print(e)
340+
else:
341+
print("Exception not produced")
342+
343+
try:
344+
VectorType.get(shape, f32, scalable=[False, True], scalable_dims=[42])
345+
except ValueError as e:
346+
# CHECK: Scalable dimension index out of bounds.
347+
print(e)
348+
else:
349+
print("Exception not produced")
350+
311351

312352
# CHECK-LABEL: TEST: testRankedTensorType
313353
@run
@@ -337,7 +377,6 @@ def testRankedTensorType():
337377
assert RankedTensorType.get(shape, f32).encoding is None
338378

339379

340-
341380
# CHECK-LABEL: TEST: testUnrankedTensorType
342381
@run
343382
def testUnrankedTensorType():

0 commit comments

Comments (0)