Skip to content

Commit c0f3ea8

Browse files
committed
[mlir][Python] Add checking process before create an AffineMap from a permutation.
An invalid permutation will trigger a C++ assertion when attempting to create an AffineMap from the permutation. This patch adds an `isPermutation` function to check the given permutation before creating the AffineMap. Reviewed By: mehdi_amini Differential Revision: https://reviews.llvm.org/D94492
1 parent 25b3921 commit c0f3ea8

File tree

2 files changed

+24
-0
lines changed

2 files changed

+24
-0
lines changed

mlir/lib/Bindings/Python/IRModules.cpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,21 @@ static MlirStringRef toMlirStringRef(const std::string &s) {
153153
return mlirStringRefCreate(s.data(), s.size());
154154
}
155155

156+
template <typename PermutationTy>
157+
static bool isPermutation(std::vector<PermutationTy> permutation) {
158+
llvm::SmallVector<bool, 8> seen(permutation.size(), false);
159+
for (auto val : permutation) {
160+
if (val < permutation.size()) {
161+
if (seen[val])
162+
return false;
163+
seen[val] = true;
164+
continue;
165+
}
166+
return false;
167+
}
168+
return true;
169+
}
170+
156171
//------------------------------------------------------------------------------
157172
// Collections.
158173
//------------------------------------------------------------------------------
@@ -3914,6 +3929,9 @@ void mlir::python::populateIRSubmodule(py::module &m) {
39143929
"get_permutation",
39153930
[](std::vector<unsigned> permutation,
39163931
DefaultingPyMlirContext context) {
3932+
if (!isPermutation(permutation))
3933+
throw py::cast_error("Invalid permutation when attempting to "
3934+
"create an AffineMap");
39173935
MlirAffineMap affineMap = mlirAffineMapPermutationGet(
39183936
context->get(), permutation.size(), permutation.data());
39193937
return PyAffineMap(context->getRef(), affineMap);

mlir/test/Bindings/Python/ir_affine_map.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,12 @@ def testAffineMapGet():
7373
# CHECK: Invalid expression (None?) when attempting to create an AffineMap
7474
print(e)
7575

76+
try:
77+
AffineMap.get_permutation([1, 0, 1])
78+
except RuntimeError as e:
79+
# CHECK: Invalid permutation when attempting to create an AffineMap
80+
print(e)
81+
7682
try:
7783
map3.get_submap([42])
7884
except ValueError as e:

0 commit comments

Comments
 (0)