Skip to content

Commit a643bd3

Browse files
committed
[mlir] add permutation utility
I found myself typing this code several times at different places by now, so time to make this a general utility instead. Given a permutation, it returns the permuted position of the input, for example (i,j,k) -> (k,i,j) yields position 1 for input 0. Reviewed By: ftynse Differential Revision: https://reviews.llvm.org/D108347
1 parent 194b080 commit a643bd3

File tree

2 files changed

+12
-0
lines changed

2 files changed

+12
-0
lines changed

mlir/include/mlir/IR/AffineMap.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,10 @@ class AffineMap {
162162
/// when the caller knows it is safe to do so.
163163
unsigned getDimPosition(unsigned idx) const;
164164

165+
/// Extracts the permuted position where given input index resides.
166+
/// Fails when called on a non-permutation.
167+
unsigned getPermutedPosition(unsigned input) const;
168+
165169
/// Return true if any affine expression involves AffineDimExpr `position`.
166170
bool isFunctionOfDim(unsigned position) const {
167171
return llvm::any_of(getResults(), [&](AffineExpr e) {

mlir/lib/IR/AffineMap.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -336,6 +336,14 @@ unsigned AffineMap::getDimPosition(unsigned idx) const {
336336
return getResult(idx).cast<AffineDimExpr>().getPosition();
337337
}
338338

339+
unsigned AffineMap::getPermutedPosition(unsigned input) const {
340+
assert(isPermutation() && "invalid permutation request");
341+
for (unsigned i = 0, numResults = getNumResults(); i < numResults; i++)
342+
if (getDimPosition(i) == input)
343+
return i;
344+
llvm_unreachable("incorrect permutation request");
345+
}
346+
339347
/// Folds the results of the application of an affine map on the provided
340348
/// operands to a constant if possible. Returns false if the folding happens,
341349
/// true otherwise.

0 commit comments

Comments
 (0)