Skip to content

Commit afba555

Browse files
committed
[mlir][vector] Add actualRank output parameter to createUnrollIterator()
This provides an easy way of finding the actual rank the vector type will/can be unrolled to (which may be > the `targetRank`).
1 parent bc946f5 commit afba555

File tree

2 files changed

+21
-6
lines changed

2 files changed

+21
-6
lines changed

mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -86,20 +86,24 @@ bool isContiguousSlice(MemRefType memrefType, VectorType vectorType);
8686
///
8787
/// If no leading dimensions can be unrolled an empty optional will be returned.
8888
///
89+
/// The actual rank the vector type can be unrolled to can be discovered by
90+
/// passing a pointer (to an int64_t) to the optional `actualRank` parameter.
91+
///
8992
/// Examples:
9093
///
9194
/// For vType = vector<2x3x4> and targetRank = 1
9295
///
9396
/// The resulting iterator will yield:
94-
/// [0, 0], [0, 1], [0, 2], [1, 0], [1, 1], [1, 2]
97+
/// [0, 0], [0, 1], [0, 2], [1, 0], [1, 1], [1, 2] (actualRank = 1)
9598
///
9699
/// For vType = vector<3x[4]x5> and targetRank = 0
97100
///
98101
/// The scalable dimension blocks unrolling so the iterator yields only:
99-
/// [0], [1], [2]
102+
/// [0], [1], [2] (actualRank = 2)
100103
///
101104
std::optional<StaticTileOffsetRange>
102-
createUnrollIterator(VectorType vType, int64_t targetRank = 1);
105+
createUnrollIterator(VectorType vType, int64_t targetRank = 1,
106+
int64_t *actualRank = nullptr);
103107

104108
/// A wrapper for getMixedSizes for vector.transfer_read and
105109
/// vector.transfer_write Ops (for source and destination, respectively).

mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -285,24 +285,35 @@ bool vector::isContiguousSlice(MemRefType memrefType, VectorType vectorType) {
285285
}
286286

287287
std::optional<StaticTileOffsetRange>
288-
vector::createUnrollIterator(VectorType vType, int64_t targetRank) {
289-
if (vType.getRank() <= targetRank)
288+
vector::createUnrollIterator(VectorType vType, int64_t targetRank,
289+
int64_t *actualRank) {
290+
auto reportActualRank = [&](int64_t rank) {
291+
if (actualRank)
292+
*actualRank = rank;
293+
};
294+
auto vectorRank = vType.getRank();
295+
if (vectorRank <= targetRank) {
296+
reportActualRank(vectorRank);
290297
return {};
298+
}
291299
// Attempt to unroll until targetRank or the first scalable dimension (which
292300
// cannot be unrolled).
293301
auto shapeToUnroll = vType.getShape().drop_back(targetRank);
294302
auto scalableDimsToUnroll = vType.getScalableDims().drop_back(targetRank);
295303
auto it =
296304
std::find(scalableDimsToUnroll.begin(), scalableDimsToUnroll.end(), true);
297305
auto firstScalableDim = it - scalableDimsToUnroll.begin();
298-
if (firstScalableDim == 0)
306+
if (firstScalableDim == 0) {
307+
reportActualRank(vectorRank);
299308
return {};
309+
}
300310
// All scalable dimensions should be removed now.
301311
scalableDimsToUnroll = scalableDimsToUnroll.slice(0, firstScalableDim);
302312
assert(!llvm::is_contained(scalableDimsToUnroll, true) &&
303313
"unexpected leading scalable dimension");
304314
// Create an unroll iterator for leading dimensions.
305315
shapeToUnroll = shapeToUnroll.slice(0, firstScalableDim);
316+
reportActualRank(vectorRank - shapeToUnroll.size());
306317
return StaticTileOffsetRange(shapeToUnroll, /*unrollStep=*/1);
307318
}
308319

0 commit comments

Comments
 (0)