Skip to content

[mlir] support scalable vectors in python bindings #71050

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Nov 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions mlir/include/mlir-c/BuiltinTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,32 @@ MLIR_CAPI_EXPORTED MlirType mlirVectorTypeGetChecked(MlirLocation loc,
const int64_t *shape,
MlirType elementType);

/// Creates a scalable vector type with the shape identified by its rank and
/// dimensions. A subset of dimensions may be marked as scalable via the
/// corresponding flag list, which is expected to have as many entries as the
/// rank of the vector. The vector is created in the same context as the element
/// type.
MLIR_CAPI_EXPORTED MlirType mlirVectorTypeGetScalable(intptr_t rank,
const int64_t *shape,
const bool *scalable,
MlirType elementType);

/// Same as "mlirVectorTypeGetScalable" but returns a nullptr wrapping MlirType
/// on illegal arguments, emitting appropriate diagnostics.
MLIR_CAPI_EXPORTED
MlirType mlirVectorTypeGetScalableChecked(MlirLocation loc, intptr_t rank,
const int64_t *shape,
const bool *scalable,
MlirType elementType);

/// Checks whether the given vector type is scalable, i.e., has at least one
/// scalable dimension.
MLIR_CAPI_EXPORTED bool mlirVectorTypeIsScalable(MlirType type);

/// Checks whether the "dim"-th dimension of the given vector is scalable.
MLIR_CAPI_EXPORTED bool mlirVectorTypeIsDimScalable(MlirType type,
intptr_t dim);

//===----------------------------------------------------------------------===//
// Ranked / Unranked Tensor type.
//===----------------------------------------------------------------------===//
Expand Down
69 changes: 56 additions & 13 deletions mlir/lib/Bindings/Python/IRTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -462,19 +462,62 @@ class PyVectorType : public PyConcreteType<PyVectorType, PyShapedType> {
using PyConcreteType::PyConcreteType;

// Registers the Python-visible API of VectorType on the pybind11 class.
// NOTE(review): this span is a rendered diff; the first c.def_static (the
// inline lambda) is the PRE-change registration and the second (bound to
// &PyVectorType::get) is its replacement — in the merged file only the
// second registration exists. Verify against the applied source.
static void bindDerived(ClassTy &c) {
// Old registration: shape + elementType only, no scalability support.
c.def_static(
"get",
[](std::vector<int64_t> shape, PyType &elementType,
DefaultingPyLocation loc) {
// Capture diagnostics so invalid types raise MLIRError with context.
PyMlirContext::ErrorCapture errors(loc->getContext());
MlirType t = mlirVectorTypeGetChecked(loc, shape.size(), shape.data(),
elementType);
if (mlirTypeIsNull(t))
throw MLIRError("Invalid type", errors.take());
return PyVectorType(elementType.getContext(), t);
},
py::arg("shape"), py::arg("elementType"), py::arg("loc") = py::none(),
"Create a vector type");
// New registration: adds mutually-exclusive keyword-only 'scalable'
// (per-dim bool list) and 'scalable_dims' (indices of scalable dims).
c.def_static("get", &PyVectorType::get, py::arg("shape"),
py::arg("elementType"), py::kw_only(),
py::arg("scalable") = py::none(),
py::arg("scalable_dims") = py::none(),
py::arg("loc") = py::none(), "Create a vector type")
// True iff at least one dimension of the vector is scalable.
.def_property_readonly(
"scalable",
[](MlirType self) { return mlirVectorTypeIsScalable(self); })
// Per-dimension scalability flags, one bool per rank dimension.
.def_property_readonly("scalable_dims", [](MlirType self) {
std::vector<bool> scalableDims;
size_t rank = static_cast<size_t>(mlirShapedTypeGetRank(self));
scalableDims.reserve(rank);
for (size_t i = 0; i < rank; ++i)
scalableDims.push_back(mlirVectorTypeIsDimScalable(self, i));
return scalableDims;
});
}

private:
/// Creates a vector type, optionally scalable.
///
/// Exactly one of `scalable` / `scalableDims` may be provided:
///  - `scalable`: a list of booleans, one per dimension of `shape`;
///  - `scalableDims`: a list of dimension indices that are scalable.
/// Throws py::value_error on inconsistent keyword arguments and MLIRError
/// (with captured diagnostics) when the resulting type is invalid.
static PyVectorType get(std::vector<int64_t> shape, PyType &elementType,
                        std::optional<py::list> scalable,
                        std::optional<std::vector<int64_t>> scalableDims,
                        DefaultingPyLocation loc) {
  if (scalable && scalableDims) {
    throw py::value_error("'scalable' and 'scalable_dims' kwargs "
                          "are mutually exclusive.");
  }

  // Capture diagnostics emitted by the *Checked constructors so failures
  // surface as a Python exception instead of printing to stderr.
  PyMlirContext::ErrorCapture errors(loc->getContext());
  MlirType type;
  if (scalable || scalableDims) {
    // Normalize either spelling into a per-dimension flag vector, then make
    // a single call to the scalable constructor (the two branches previously
    // duplicated this call).
    SmallVector<bool> scalableDimFlags;
    if (scalable) {
      if (scalable->size() != shape.size())
        throw py::value_error("Expected len(scalable) == len(shape).");
      scalableDimFlags = llvm::to_vector(llvm::map_range(
          *scalable, [](const py::handle &h) { return h.cast<bool>(); }));
    } else {
      scalableDimFlags.assign(shape.size(), false);
      for (int64_t dim : *scalableDims) {
        // Check negativity before the size_t cast to keep the intent clear.
        if (dim < 0 || static_cast<size_t>(dim) >= scalableDimFlags.size())
          throw py::value_error("Scalable dimension index out of bounds.");
        scalableDimFlags[dim] = true;
      }
    }
    type = mlirVectorTypeGetScalableChecked(loc, shape.size(), shape.data(),
                                            scalableDimFlags.data(),
                                            elementType);
  } else {
    type = mlirVectorTypeGetChecked(loc, shape.size(), shape.data(),
                                    elementType);
  }
  if (mlirTypeIsNull(type))
    throw MLIRError("Invalid type", errors.take());
  return PyVectorType(elementType.getContext(), type);
}
};

Expand Down
25 changes: 25 additions & 0 deletions mlir/lib/CAPI/IR/BuiltinTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,31 @@ MlirType mlirVectorTypeGetChecked(MlirLocation loc, intptr_t rank,
unwrap(elementType)));
}

/// C API: builds a vector type with per-dimension scalability flags; `shape`
/// and `scalable` are both expected to hold `rank` entries.
MlirType mlirVectorTypeGetScalable(intptr_t rank, const int64_t *shape,
                                   const bool *scalable, MlirType elementType) {
  const size_t numDims = static_cast<size_t>(rank);
  llvm::ArrayRef<int64_t> shapeRef(shape, numDims);
  llvm::ArrayRef<bool> scalableRef(scalable, numDims);
  return wrap(VectorType::get(shapeRef, unwrap(elementType), scalableRef));
}

/// C API: like mlirVectorTypeGetScalable, but emits diagnostics at `loc` and
/// returns a null type instead of asserting on invalid arguments.
MlirType mlirVectorTypeGetScalableChecked(MlirLocation loc, intptr_t rank,
                                          const int64_t *shape,
                                          const bool *scalable,
                                          MlirType elementType) {
  const size_t numDims = static_cast<size_t>(rank);
  return wrap(VectorType::getChecked(unwrap(loc),
                                     llvm::ArrayRef<int64_t>(shape, numDims),
                                     unwrap(elementType),
                                     llvm::ArrayRef<bool>(scalable, numDims)));
}

/// C API: returns true iff the vector type has at least one scalable dim.
/// Uses the free-function llvm::cast — the member Type::cast<T>() form is
/// deprecated in MLIR.
bool mlirVectorTypeIsScalable(MlirType type) {
  return llvm::cast<VectorType>(unwrap(type)).isScalable();
}

/// C API: returns true iff dimension `dim` of the vector type is scalable.
/// `dim` must be in [0, rank) — no bounds check is performed here.
/// Uses the free-function llvm::cast — the member Type::cast<T>() form is
/// deprecated in MLIR.
bool mlirVectorTypeIsDimScalable(MlirType type, intptr_t dim) {
  return llvm::cast<VectorType>(unwrap(type)).getScalableDims()[dim];
}

//===----------------------------------------------------------------------===//
// Ranked / Unranked tensor type.
//===----------------------------------------------------------------------===//
Expand Down
34 changes: 24 additions & 10 deletions mlir/test/CAPI/ir.c
Original file line number Diff line number Diff line change
Expand Up @@ -743,13 +743,27 @@ static int printBuiltinTypes(MlirContext ctx) {
fprintf(stderr, "\n");
// CHECK: vector<2x3xf32>

// Scalable vector type.
bool scalable[] = {false, true};
MlirType scalableVector = mlirVectorTypeGetScalable(
sizeof(shape) / sizeof(int64_t), shape, scalable, f32);
if (!mlirTypeIsAVector(scalableVector))
return 16;
if (!mlirVectorTypeIsScalable(scalableVector) ||
mlirVectorTypeIsDimScalable(scalableVector, 0) ||
!mlirVectorTypeIsDimScalable(scalableVector, 1))
return 17;
mlirTypeDump(scalableVector);
fprintf(stderr, "\n");
// CHECK: vector<2x[3]xf32>

// Ranked tensor type.
MlirType rankedTensor = mlirRankedTensorTypeGet(
sizeof(shape) / sizeof(int64_t), shape, f32, mlirAttributeGetNull());
if (!mlirTypeIsATensor(rankedTensor) ||
!mlirTypeIsARankedTensor(rankedTensor) ||
!mlirAttributeIsNull(mlirRankedTensorTypeGetEncoding(rankedTensor)))
return 16;
return 18;
mlirTypeDump(rankedTensor);
fprintf(stderr, "\n");
// CHECK: tensor<2x3xf32>
Expand All @@ -759,7 +773,7 @@ static int printBuiltinTypes(MlirContext ctx) {
if (!mlirTypeIsATensor(unrankedTensor) ||
!mlirTypeIsAUnrankedTensor(unrankedTensor) ||
mlirShapedTypeHasRank(unrankedTensor))
return 17;
return 19;
mlirTypeDump(unrankedTensor);
fprintf(stderr, "\n");
// CHECK: tensor<*xf32>
Expand All @@ -770,7 +784,7 @@ static int printBuiltinTypes(MlirContext ctx) {
f32, sizeof(shape) / sizeof(int64_t), shape, memSpace2);
if (!mlirTypeIsAMemRef(memRef) ||
!mlirAttributeEqual(mlirMemRefTypeGetMemorySpace(memRef), memSpace2))
return 18;
return 20;
mlirTypeDump(memRef);
fprintf(stderr, "\n");
// CHECK: memref<2x3xf32, 2>
Expand All @@ -782,7 +796,7 @@ static int printBuiltinTypes(MlirContext ctx) {
mlirTypeIsAMemRef(unrankedMemRef) ||
!mlirAttributeEqual(mlirUnrankedMemrefGetMemorySpace(unrankedMemRef),
memSpace4))
return 19;
return 21;
mlirTypeDump(unrankedMemRef);
fprintf(stderr, "\n");
// CHECK: memref<*xf32, 4>
Expand All @@ -793,7 +807,7 @@ static int printBuiltinTypes(MlirContext ctx) {
if (!mlirTypeIsATuple(tuple) || mlirTupleTypeGetNumTypes(tuple) != 2 ||
!mlirTypeEqual(mlirTupleTypeGetType(tuple, 0), unrankedMemRef) ||
!mlirTypeEqual(mlirTupleTypeGetType(tuple, 1), f32))
return 20;
return 22;
mlirTypeDump(tuple);
fprintf(stderr, "\n");
// CHECK: tuple<memref<*xf32, 4>, f32>
Expand All @@ -805,16 +819,16 @@ static int printBuiltinTypes(MlirContext ctx) {
mlirIntegerTypeGet(ctx, 64)};
MlirType funcType = mlirFunctionTypeGet(ctx, 2, funcInputs, 3, funcResults);
if (mlirFunctionTypeGetNumInputs(funcType) != 2)
return 21;
return 23;
if (mlirFunctionTypeGetNumResults(funcType) != 3)
return 22;
return 24;
if (!mlirTypeEqual(funcInputs[0], mlirFunctionTypeGetInput(funcType, 0)) ||
!mlirTypeEqual(funcInputs[1], mlirFunctionTypeGetInput(funcType, 1)))
return 23;
return 25;
if (!mlirTypeEqual(funcResults[0], mlirFunctionTypeGetResult(funcType, 0)) ||
!mlirTypeEqual(funcResults[1], mlirFunctionTypeGetResult(funcType, 1)) ||
!mlirTypeEqual(funcResults[2], mlirFunctionTypeGetResult(funcType, 2)))
return 24;
return 26;
mlirTypeDump(funcType);
fprintf(stderr, "\n");
// CHECK: (index, i1) -> (i16, i32, i64)
Expand All @@ -829,7 +843,7 @@ static int printBuiltinTypes(MlirContext ctx) {
!mlirStringRefEqual(mlirOpaqueTypeGetDialectNamespace(opaque),
namespace) ||
!mlirStringRefEqual(mlirOpaqueTypeGetData(opaque), data))
return 25;
return 27;
mlirTypeDump(opaque);
fprintf(stderr, "\n");
// CHECK: !dialect.type
Expand Down
43 changes: 41 additions & 2 deletions mlir/test/python/ir/builtin_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,14 +300,54 @@ def testVectorType():

none = NoneType.get()
try:
vector_invalid = VectorType.get(shape, none)
VectorType.get(shape, none)
except MLIRError as e:
# CHECK: Invalid type:
# CHECK: error: unknown: vector elements must be int/index/float type but got 'none'
print(e)
else:
print("Exception not produced")

scalable_1 = VectorType.get(shape, f32, scalable=[False, True])
scalable_2 = VectorType.get([2, 3, 4], f32, scalable=[True, False, True])
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Base 10 is better than base 2 😃 — i.e., it would be better if the Python API just required the user to list the dims that are scalable.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is consistent with how the C++ constructors are implemented. I prefer consistency over niceness here; one can always add a wrapper.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added a separate keyword for this, it's Python after all.

assert scalable_1.scalable
assert scalable_2.scalable
assert scalable_1.scalable_dims == [False, True]
assert scalable_2.scalable_dims == [True, False, True]
# CHECK: scalable 1: vector<2x[3]xf32>
print("scalable 1: ", scalable_1)
# CHECK: scalable 2: vector<[2]x3x[4]xf32>
print("scalable 2: ", scalable_2)

scalable_3 = VectorType.get(shape, f32, scalable_dims=[1])
scalable_4 = VectorType.get([2, 3, 4], f32, scalable_dims=[0, 2])
assert scalable_3 == scalable_1
assert scalable_4 == scalable_2

try:
VectorType.get(shape, f32, scalable=[False, True, True])
except ValueError as e:
# CHECK: Expected len(scalable) == len(shape).
print(e)
else:
print("Exception not produced")

try:
VectorType.get(shape, f32, scalable=[False, True], scalable_dims=[1])
except ValueError as e:
# CHECK: kwargs are mutually exclusive.
print(e)
else:
print("Exception not produced")

try:
VectorType.get(shape, f32, scalable=[False, True], scalable_dims=[42])
except ValueError as e:
# CHECK: Scalable dimension index out of bounds.
print(e)
else:
print("Exception not produced")


# CHECK-LABEL: TEST: testRankedTensorType
@run
Expand Down Expand Up @@ -337,7 +377,6 @@ def testRankedTensorType():
assert RankedTensorType.get(shape, f32).encoding is None



# CHECK-LABEL: TEST: testUnrankedTensorType
@run
def testUnrankedTensorType():
Expand Down