Skip to content

Commit 96dadc9

Browse files
authored
[mlir] support scalable vectors in python bindings (#71050)
The scalable dimension functionality was added to the vector type after the bindings for it were defined, without the bindings ever being updated. Fix that.
1 parent ff67e85 commit 96dadc9

File tree

5 files changed

+172
-25
lines changed

5 files changed

+172
-25
lines changed

mlir/include/mlir-c/BuiltinTypes.h

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -271,6 +271,32 @@ MLIR_CAPI_EXPORTED MlirType mlirVectorTypeGetChecked(MlirLocation loc,
271271
const int64_t *shape,
272272
MlirType elementType);
273273

274+
/// Creates a scalable vector type with the shape identified by its rank and
275+
/// dimensions. A subset of dimensions may be marked as scalable via the
276+
/// corresponding flag list, which is expected to have as many entries as the
277+
/// rank of the vector. The vector is created in the same context as the element
278+
/// type.
279+
MLIR_CAPI_EXPORTED MlirType mlirVectorTypeGetScalable(intptr_t rank,
280+
const int64_t *shape,
281+
const bool *scalable,
282+
MlirType elementType);
283+
284+
/// Same as "mlirVectorTypeGetScalable" but returns a nullptr wrapping MlirType
285+
/// on illegal arguments, emitting appropriate diagnostics.
286+
MLIR_CAPI_EXPORTED
287+
MlirType mlirVectorTypeGetScalableChecked(MlirLocation loc, intptr_t rank,
288+
const int64_t *shape,
289+
const bool *scalable,
290+
MlirType elementType);
291+
292+
/// Checks whether the given vector type is scalable, i.e., has at least one
293+
/// scalable dimension.
294+
MLIR_CAPI_EXPORTED bool mlirVectorTypeIsScalable(MlirType type);
295+
296+
/// Checks whether the "dim"-th dimension of the given vector is scalable.
297+
MLIR_CAPI_EXPORTED bool mlirVectorTypeIsDimScalable(MlirType type,
298+
intptr_t dim);
299+
274300
//===----------------------------------------------------------------------===//
275301
// Ranked / Unranked Tensor type.
276302
//===----------------------------------------------------------------------===//

mlir/lib/Bindings/Python/IRTypes.cpp

Lines changed: 56 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -462,19 +462,62 @@ class PyVectorType : public PyConcreteType<PyVectorType, PyShapedType> {
462462
using PyConcreteType::PyConcreteType;
463463

464464
static void bindDerived(ClassTy &c) {
465-
c.def_static(
466-
"get",
467-
[](std::vector<int64_t> shape, PyType &elementType,
468-
DefaultingPyLocation loc) {
469-
PyMlirContext::ErrorCapture errors(loc->getContext());
470-
MlirType t = mlirVectorTypeGetChecked(loc, shape.size(), shape.data(),
471-
elementType);
472-
if (mlirTypeIsNull(t))
473-
throw MLIRError("Invalid type", errors.take());
474-
return PyVectorType(elementType.getContext(), t);
475-
},
476-
py::arg("shape"), py::arg("elementType"), py::arg("loc") = py::none(),
477-
"Create a vector type");
465+
c.def_static("get", &PyVectorType::get, py::arg("shape"),
466+
py::arg("elementType"), py::kw_only(),
467+
py::arg("scalable") = py::none(),
468+
py::arg("scalable_dims") = py::none(),
469+
py::arg("loc") = py::none(), "Create a vector type")
470+
.def_property_readonly(
471+
"scalable",
472+
[](MlirType self) { return mlirVectorTypeIsScalable(self); })
473+
.def_property_readonly("scalable_dims", [](MlirType self) {
474+
std::vector<bool> scalableDims;
475+
size_t rank = static_cast<size_t>(mlirShapedTypeGetRank(self));
476+
scalableDims.reserve(rank);
477+
for (size_t i = 0; i < rank; ++i)
478+
scalableDims.push_back(mlirVectorTypeIsDimScalable(self, i));
479+
return scalableDims;
480+
});
481+
}
482+
483+
private:
484+
static PyVectorType get(std::vector<int64_t> shape, PyType &elementType,
485+
std::optional<py::list> scalable,
486+
std::optional<std::vector<int64_t>> scalableDims,
487+
DefaultingPyLocation loc) {
488+
if (scalable && scalableDims) {
489+
throw py::value_error("'scalable' and 'scalable_dims' kwargs "
490+
"are mutually exclusive.");
491+
}
492+
493+
PyMlirContext::ErrorCapture errors(loc->getContext());
494+
MlirType type;
495+
if (scalable) {
496+
if (scalable->size() != shape.size())
497+
throw py::value_error("Expected len(scalable) == len(shape).");
498+
499+
SmallVector<bool> scalableDimFlags = llvm::to_vector(llvm::map_range(
500+
*scalable, [](const py::handle &h) { return h.cast<bool>(); }));
501+
type = mlirVectorTypeGetScalableChecked(loc, shape.size(), shape.data(),
502+
scalableDimFlags.data(),
503+
elementType);
504+
} else if (scalableDims) {
505+
SmallVector<bool> scalableDimFlags(shape.size(), false);
506+
for (int64_t dim : *scalableDims) {
507+
if (static_cast<size_t>(dim) >= scalableDimFlags.size() || dim < 0)
508+
throw py::value_error("Scalable dimension index out of bounds.");
509+
scalableDimFlags[dim] = true;
510+
}
511+
type = mlirVectorTypeGetScalableChecked(loc, shape.size(), shape.data(),
512+
scalableDimFlags.data(),
513+
elementType);
514+
} else {
515+
type = mlirVectorTypeGetChecked(loc, shape.size(), shape.data(),
516+
elementType);
517+
}
518+
if (mlirTypeIsNull(type))
519+
throw MLIRError("Invalid type", errors.take());
520+
return PyVectorType(elementType.getContext(), type);
478521
}
479522
};
480523

mlir/lib/CAPI/IR/BuiltinTypes.cpp

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -281,6 +281,31 @@ MlirType mlirVectorTypeGetChecked(MlirLocation loc, intptr_t rank,
281281
unwrap(elementType)));
282282
}
283283

284+
MlirType mlirVectorTypeGetScalable(intptr_t rank, const int64_t *shape,
285+
const bool *scalable, MlirType elementType) {
286+
return wrap(VectorType::get(
287+
llvm::ArrayRef(shape, static_cast<size_t>(rank)), unwrap(elementType),
288+
llvm::ArrayRef(scalable, static_cast<size_t>(rank))));
289+
}
290+
291+
MlirType mlirVectorTypeGetScalableChecked(MlirLocation loc, intptr_t rank,
292+
const int64_t *shape,
293+
const bool *scalable,
294+
MlirType elementType) {
295+
return wrap(VectorType::getChecked(
296+
unwrap(loc), llvm::ArrayRef(shape, static_cast<size_t>(rank)),
297+
unwrap(elementType),
298+
llvm::ArrayRef(scalable, static_cast<size_t>(rank))));
299+
}
300+
301+
bool mlirVectorTypeIsScalable(MlirType type) {
302+
return unwrap(type).cast<VectorType>().isScalable();
303+
}
304+
305+
bool mlirVectorTypeIsDimScalable(MlirType type, intptr_t dim) {
306+
return unwrap(type).cast<VectorType>().getScalableDims()[dim];
307+
}
308+
284309
//===----------------------------------------------------------------------===//
285310
// Ranked / Unranked tensor type.
286311
//===----------------------------------------------------------------------===//

mlir/test/CAPI/ir.c

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -746,13 +746,27 @@ static int printBuiltinTypes(MlirContext ctx) {
746746
fprintf(stderr, "\n");
747747
// CHECK: vector<2x3xf32>
748748

749+
// Scalable vector type.
750+
bool scalable[] = {false, true};
751+
MlirType scalableVector = mlirVectorTypeGetScalable(
752+
sizeof(shape) / sizeof(int64_t), shape, scalable, f32);
753+
if (!mlirTypeIsAVector(scalableVector))
754+
return 16;
755+
if (!mlirVectorTypeIsScalable(scalableVector) ||
756+
mlirVectorTypeIsDimScalable(scalableVector, 0) ||
757+
!mlirVectorTypeIsDimScalable(scalableVector, 1))
758+
return 17;
759+
mlirTypeDump(scalableVector);
760+
fprintf(stderr, "\n");
761+
// CHECK: vector<2x[3]xf32>
762+
749763
// Ranked tensor type.
750764
MlirType rankedTensor = mlirRankedTensorTypeGet(
751765
sizeof(shape) / sizeof(int64_t), shape, f32, mlirAttributeGetNull());
752766
if (!mlirTypeIsATensor(rankedTensor) ||
753767
!mlirTypeIsARankedTensor(rankedTensor) ||
754768
!mlirAttributeIsNull(mlirRankedTensorTypeGetEncoding(rankedTensor)))
755-
return 16;
769+
return 18;
756770
mlirTypeDump(rankedTensor);
757771
fprintf(stderr, "\n");
758772
// CHECK: tensor<2x3xf32>
@@ -762,7 +776,7 @@ static int printBuiltinTypes(MlirContext ctx) {
762776
if (!mlirTypeIsATensor(unrankedTensor) ||
763777
!mlirTypeIsAUnrankedTensor(unrankedTensor) ||
764778
mlirShapedTypeHasRank(unrankedTensor))
765-
return 17;
779+
return 19;
766780
mlirTypeDump(unrankedTensor);
767781
fprintf(stderr, "\n");
768782
// CHECK: tensor<*xf32>
@@ -773,7 +787,7 @@ static int printBuiltinTypes(MlirContext ctx) {
773787
f32, sizeof(shape) / sizeof(int64_t), shape, memSpace2);
774788
if (!mlirTypeIsAMemRef(memRef) ||
775789
!mlirAttributeEqual(mlirMemRefTypeGetMemorySpace(memRef), memSpace2))
776-
return 18;
790+
return 20;
777791
mlirTypeDump(memRef);
778792
fprintf(stderr, "\n");
779793
// CHECK: memref<2x3xf32, 2>
@@ -785,7 +799,7 @@ static int printBuiltinTypes(MlirContext ctx) {
785799
mlirTypeIsAMemRef(unrankedMemRef) ||
786800
!mlirAttributeEqual(mlirUnrankedMemrefGetMemorySpace(unrankedMemRef),
787801
memSpace4))
788-
return 19;
802+
return 21;
789803
mlirTypeDump(unrankedMemRef);
790804
fprintf(stderr, "\n");
791805
// CHECK: memref<*xf32, 4>
@@ -796,7 +810,7 @@ static int printBuiltinTypes(MlirContext ctx) {
796810
if (!mlirTypeIsATuple(tuple) || mlirTupleTypeGetNumTypes(tuple) != 2 ||
797811
!mlirTypeEqual(mlirTupleTypeGetType(tuple, 0), unrankedMemRef) ||
798812
!mlirTypeEqual(mlirTupleTypeGetType(tuple, 1), f32))
799-
return 20;
813+
return 22;
800814
mlirTypeDump(tuple);
801815
fprintf(stderr, "\n");
802816
// CHECK: tuple<memref<*xf32, 4>, f32>
@@ -808,16 +822,16 @@ static int printBuiltinTypes(MlirContext ctx) {
808822
mlirIntegerTypeGet(ctx, 64)};
809823
MlirType funcType = mlirFunctionTypeGet(ctx, 2, funcInputs, 3, funcResults);
810824
if (mlirFunctionTypeGetNumInputs(funcType) != 2)
811-
return 21;
825+
return 23;
812826
if (mlirFunctionTypeGetNumResults(funcType) != 3)
813-
return 22;
827+
return 24;
814828
if (!mlirTypeEqual(funcInputs[0], mlirFunctionTypeGetInput(funcType, 0)) ||
815829
!mlirTypeEqual(funcInputs[1], mlirFunctionTypeGetInput(funcType, 1)))
816-
return 23;
830+
return 25;
817831
if (!mlirTypeEqual(funcResults[0], mlirFunctionTypeGetResult(funcType, 0)) ||
818832
!mlirTypeEqual(funcResults[1], mlirFunctionTypeGetResult(funcType, 1)) ||
819833
!mlirTypeEqual(funcResults[2], mlirFunctionTypeGetResult(funcType, 2)))
820-
return 24;
834+
return 26;
821835
mlirTypeDump(funcType);
822836
fprintf(stderr, "\n");
823837
// CHECK: (index, i1) -> (i16, i32, i64)
@@ -832,7 +846,7 @@ static int printBuiltinTypes(MlirContext ctx) {
832846
!mlirStringRefEqual(mlirOpaqueTypeGetDialectNamespace(opaque),
833847
namespace) ||
834848
!mlirStringRefEqual(mlirOpaqueTypeGetData(opaque), data))
835-
return 25;
849+
return 27;
836850
mlirTypeDump(opaque);
837851
fprintf(stderr, "\n");
838852
// CHECK: !dialect.type

mlir/test/python/ir/builtin_types.py

Lines changed: 41 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -300,14 +300,54 @@ def testVectorType():
300300

301301
none = NoneType.get()
302302
try:
303-
vector_invalid = VectorType.get(shape, none)
303+
VectorType.get(shape, none)
304304
except MLIRError as e:
305305
# CHECK: Invalid type:
306306
# CHECK: error: unknown: vector elements must be int/index/float type but got 'none'
307307
print(e)
308308
else:
309309
print("Exception not produced")
310310

311+
scalable_1 = VectorType.get(shape, f32, scalable=[False, True])
312+
scalable_2 = VectorType.get([2, 3, 4], f32, scalable=[True, False, True])
313+
assert scalable_1.scalable
314+
assert scalable_2.scalable
315+
assert scalable_1.scalable_dims == [False, True]
316+
assert scalable_2.scalable_dims == [True, False, True]
317+
# CHECK: scalable 1: vector<2x[3]xf32>
318+
print("scalable 1: ", scalable_1)
319+
# CHECK: scalable 2: vector<[2]x3x[4]xf32>
320+
print("scalable 2: ", scalable_2)
321+
322+
scalable_3 = VectorType.get(shape, f32, scalable_dims=[1])
323+
scalable_4 = VectorType.get([2, 3, 4], f32, scalable_dims=[0, 2])
324+
assert scalable_3 == scalable_1
325+
assert scalable_4 == scalable_2
326+
327+
try:
328+
VectorType.get(shape, f32, scalable=[False, True, True])
329+
except ValueError as e:
330+
# CHECK: Expected len(scalable) == len(shape).
331+
print(e)
332+
else:
333+
print("Exception not produced")
334+
335+
try:
336+
VectorType.get(shape, f32, scalable=[False, True], scalable_dims=[1])
337+
except ValueError as e:
338+
# CHECK: kwargs are mutually exclusive.
339+
print(e)
340+
else:
341+
print("Exception not produced")
342+
343+
try:
344+
VectorType.get(shape, f32, scalable=[False, True], scalable_dims=[42])
345+
except ValueError as e:
346+
# CHECK: Scalable dimension index out of bounds.
347+
print(e)
348+
else:
349+
print("Exception not produced")
350+
311351

312352
# CHECK-LABEL: TEST: testRankedTensorType
313353
@run
@@ -337,7 +377,6 @@ def testRankedTensorType():
337377
assert RankedTensorType.get(shape, f32).encoding is None
338378

339379

340-
341380
# CHECK-LABEL: TEST: testUnrankedTensorType
342381
@run
343382
def testUnrankedTensorType():

0 commit comments

Comments
 (0)