Skip to content

Commit 831041b

Browse files
[mlir][vector] Cleanup VectorUnroll and create a generic tile iteration utility
This change refactors some of the utilities used to unroll larger vector computations into smaller vector computations. In fact, the indexing computations used here are rather generic and are useful in other dialects or downstream projects. Therefore, a utility for iterating over all possible tile offsets for a particular pair of static (shape, tiled shape) is introduced in IndexingUtils and replaces the existing computations in the vector unrolling transformations. This builds off of the refactoring of IndexingUtils introduced in 203fad4. Reviewed By: nicolasvasilache Differential Revision: https://reviews.llvm.org/D150000
1 parent ed4daea commit 831041b

File tree

7 files changed

+311
-133
lines changed

7 files changed

+311
-133
lines changed

mlir/include/mlir/Dialect/Utils/IndexingUtils.h

Lines changed: 152 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,9 @@
1818
#include "mlir/Support/LLVM.h"
1919
#include "llvm/ADT/ArrayRef.h"
2020
#include "llvm/ADT/SmallVector.h"
21+
#include "llvm/ADT/iterator.h"
2122
#include <optional>
23+
#include <utility>
2224

2325
namespace mlir {
2426
class ArrayAttr;
@@ -195,6 +197,23 @@ SmallVector<AffineExpr> delinearize(AffineExpr linearIndex,
195197
// Permutation utils.
196198
//===----------------------------------------------------------------------===//
197199

200+
template <typename T>
201+
SmallVector<T> applyPermutation(ArrayRef<T> input,
202+
ArrayRef<int64_t> permutation) {
203+
assert(input.size() == permutation.size() &&
204+
"expected input rank to equal permutation rank");
205+
auto permutationRange = llvm::map_range(
206+
llvm::seq<unsigned>(0, input.size()),
207+
[&](int64_t idx) -> T { return input[permutation[idx]]; });
208+
return llvm::to_vector(permutationRange);
209+
}
210+
211+
template <typename T>
212+
SmallVector<T> applyPermutation(const SmallVectorImpl<T> &input,
213+
ArrayRef<int64_t> permutation) {
214+
return applyPermutation(ArrayRef(input), permutation);
215+
}
216+
198217
/// Apply the permutation defined by `permutation` to `inVec`.
199218
/// Element `i` in `inVec` is mapped to location `j = permutation[i]`.
200219
/// E.g.: for an input vector `inVec = ['a', 'b', 'c']` and a permutation
@@ -203,10 +222,7 @@ SmallVector<AffineExpr> delinearize(AffineExpr linearIndex,
203222
template <typename T, unsigned N>
204223
void applyPermutationToVector(SmallVector<T, N> &inVec,
205224
ArrayRef<int64_t> permutation) {
206-
SmallVector<T, N> auxVec(inVec.size());
207-
for (const auto &en : enumerate(permutation))
208-
auxVec[en.index()] = inVec[en.value()];
209-
inVec = auxVec;
225+
inVec = applyPermutation(inVec, permutation);
210226
}
211227

212228
/// Helper method to apply to inverse a permutation.
@@ -239,6 +255,138 @@ std::pair<AffineExpr, SmallVector<OpFoldResult>>
239255
computeLinearIndex(OpFoldResult sourceOffset, ArrayRef<OpFoldResult> strides,
240256
ArrayRef<OpFoldResult> indices);
241257

258+
//===----------------------------------------------------------------------===//
259+
// Utilities for decomposing larger shapes
260+
//===----------------------------------------------------------------------===//
261+
262+
namespace detail {
263+
/// Encapsulates the set of parameters that are used to make tile offset
264+
/// calculations in the TileOffsetRangeIterator.
265+
class TileOffsetRangeImpl {
266+
public:
267+
TileOffsetRangeImpl(ArrayRef<int64_t> shape, ArrayRef<int64_t> tileShape,
268+
ArrayRef<int64_t> loopOrder);
269+
270+
int64_t getMaxLinearIndex() const { return maxLinearIndex; }
271+
272+
SmallVector<int64_t> getStaticTileOffsets(int64_t linearIndex) const;
273+
274+
SmallVector<AffineExpr> getDynamicTileOffsets(AffineExpr linearIndex) const;
275+
276+
template <typename T>
277+
SmallVector<T> getTileOffsets(T linearIndex) const {
278+
if constexpr (std::is_same_v<T, int64_t>)
279+
return getStaticTileOffsets(linearIndex);
280+
else
281+
return getDynamicTileOffsets(linearIndex);
282+
}
283+
284+
private:
285+
/// The sub-shape that divides the larger outer shape (which is provided to
286+
/// the constructor).
287+
SmallVector<int64_t> tileShape;
288+
/// The inverse permutation to the `loopOrder` permutation provided in the
289+
/// constructor.
290+
SmallVector<int64_t> inverseLoopOrder;
291+
/// The strides for the basis 'div(shape, tileShape)' permuted by `loopOrder`.
292+
SmallVector<int64_t> sliceStrides;
293+
/// The maximum linear index in the iteration space given by basis 'div(shape,
294+
/// tileShape)'.
295+
int64_t maxLinearIndex;
296+
};
297+
298+
/// The STL-style iterator implementation for StaticTileOffsetRange.
299+
template <typename ElementType>
300+
class TileOffsetRangeIterator
301+
: public llvm::iterator_facade_base<TileOffsetRangeIterator<ElementType>,
302+
std::forward_iterator_tag,
303+
SmallVector<ElementType>> {
304+
public:
305+
TileOffsetRangeIterator(const TileOffsetRangeImpl &params, ElementType index)
306+
: params(params), index(index) {}
307+
308+
void operator++() { incrementIndex(1); }
309+
TileOffsetRangeIterator operator++(int) {
310+
const auto copy = *this;
311+
++*this;
312+
return copy;
313+
}
314+
315+
bool operator==(const TileOffsetRangeIterator &other) const {
316+
return index == other.index;
317+
}
318+
bool operator!=(const TileOffsetRangeIterator &other) const {
319+
return index != other.index;
320+
}
321+
322+
SmallVector<ElementType> operator*() const {
323+
return params.getTileOffsets(index);
324+
}
325+
void operator+=(int64_t offset) { incrementIndex(offset); }
326+
327+
private:
328+
void incrementIndex(int64_t offset) { index = index + offset; }
329+
const TileOffsetRangeImpl params;
330+
int64_t index;
331+
};
332+
} // namespace detail
333+
334+
/// A range-style iterator that allows for iterating over the offsets of all
335+
/// potential tiles of size `tileShape` within the larger shape `shape`, using
336+
/// an ordering specified by `loopOrder`. The `loopOrder` specifies the order of
337+
/// unrolling by numbering the dimensions in order from "outer most for loop"
338+
/// (slowest changing) to "inner most for loop" (fastest changing).
339+
///
340+
/// For example, for `shape = {10, 20, 30}`, `tileShape = {5, 10, 15}`, and
341+
/// `loopOrder={2, 0, 1}`, the iterating over this range will yield offsets:
342+
///
343+
/// ```
344+
/// {0, 0, 0}, {0, 10, 0}, {5, 0, 0}, {5, 10, 0}, {0, 0, 15},
345+
/// {0, 10, 15}, {5, 0, 15}, {0, 10, 15}, {5, 10, 15}
346+
/// ```
347+
///
348+
/// This is useful in contexts where a vector computation over a larger shape
349+
/// needs to be unrolled to a set of operations on subsets of the original
350+
/// operands, such as during the "vector unrolling" transformations.
351+
///
352+
/// The size of `tileShape` must be less-than-or-equal-to the size of `shape`.a
353+
/// If the rank of `tileShape` is smaller than `shape`, then `tileShape`
354+
/// elements correspond to the trailing dimensions of `shape`, and the leading
355+
/// dimensions are considered untiled and `tileShape` is effectively prepended
356+
/// with the leading dims of `shape`.
357+
class StaticTileOffsetRange {
358+
public:
359+
using IteratorTy = detail::TileOffsetRangeIterator<int64_t>;
360+
using ParamsTy = detail::TileOffsetRangeImpl;
361+
362+
StaticTileOffsetRange(ArrayRef<int64_t> shape, ArrayRef<int64_t> tileShape,
363+
ArrayRef<int64_t> loopOrder)
364+
: params(shape, tileShape, loopOrder), beginValue(params, 0),
365+
pastEndValue(params, params.getMaxLinearIndex()) {
366+
assert(shape.size() >= tileShape.size());
367+
assert(loopOrder.size() == shape.size());
368+
}
369+
370+
/// Create the range with identity loop order.
371+
StaticTileOffsetRange(ArrayRef<int64_t> shape, ArrayRef<int64_t> tileShape)
372+
: params(shape, tileShape,
373+
llvm::to_vector(llvm::seq<int64_t>(0, shape.size()))),
374+
beginValue(params, 0),
375+
pastEndValue(params, params.getMaxLinearIndex()) {
376+
assert(shape.size() >= tileShape.size());
377+
}
378+
379+
IteratorTy begin() const { return beginValue; }
380+
IteratorTy end() const { return pastEndValue; }
381+
382+
/// Returns the total number of tiles that fit in the larger shape.
383+
size_t size() const { return params.getMaxLinearIndex(); }
384+
385+
private:
386+
const ParamsTy params;
387+
IteratorTy beginValue;
388+
IteratorTy pastEndValue;
389+
};
242390
} // namespace mlir
243391

244392
#endif // MLIR_DIALECT_UTILS_INDEXINGUTILS_H

mlir/include/mlir/IR/AffineExpr.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include "mlir/Support/LLVM.h"
1818
#include "llvm/ADT/DenseMapInfo.h"
1919
#include "llvm/ADT/Hashing.h"
20+
#include "llvm/ADT/SmallVector.h"
2021
#include "llvm/Support/Casting.h"
2122
#include <functional>
2223
#include <type_traits>
@@ -250,6 +251,8 @@ inline AffineExpr operator-(int64_t val, AffineExpr expr) {
250251
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context);
251252
AffineExpr getAffineSymbolExpr(unsigned position, MLIRContext *context);
252253
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context);
254+
SmallVector<AffineExpr> getAffineConstantExprs(ArrayRef<int64_t> constants,
255+
MLIRContext *context);
253256
AffineExpr getAffineBinaryOpExpr(AffineExprKind kind, AffineExpr lhs,
254257
AffineExpr rhs);
255258

mlir/lib/Dialect/Utils/IndexingUtils.cpp

Lines changed: 56 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -181,9 +181,8 @@ AffineExpr mlir::linearize(MLIRContext *ctx, ArrayRef<AffineExpr> offsets,
181181

182182
AffineExpr mlir::linearize(MLIRContext *ctx, ArrayRef<AffineExpr> offsets,
                           ArrayRef<int64_t> basis) {
  // Materialize the constant basis, then defer to the AffineExpr overload.
  SmallVector<AffineExpr> basisExprs = getAffineConstantExprs(basis, ctx);
  return linearize(ctx, offsets, basisExprs);
}
188187

189188
SmallVector<AffineExpr> mlir::delinearize(AffineExpr linearIndex,
@@ -196,9 +195,7 @@ SmallVector<AffineExpr> mlir::delinearize(AffineExpr linearIndex,
196195
SmallVector<AffineExpr> mlir::delinearize(AffineExpr linearIndex,
197196
ArrayRef<int64_t> strides) {
198197
MLIRContext *ctx = linearIndex.getContext();
199-
SmallVector<AffineExpr> basisExprs = llvm::to_vector(llvm::map_range(
200-
strides, [ctx](int64_t v) { return getAffineConstantExpr(v, ctx); }));
201-
return delinearize(linearIndex, ArrayRef<AffineExpr>{basisExprs});
198+
return delinearize(linearIndex, getAffineConstantExprs(strides, ctx));
202199
}
203200

204201
//===----------------------------------------------------------------------===//
@@ -302,3 +299,56 @@ mlir::computeLinearIndex(OpFoldResult sourceOffset,
302299

303300
return {expr, values};
304301
}
302+
303+
//===----------------------------------------------------------------------===//
304+
// TileOffsetRange
305+
//===----------------------------------------------------------------------===//
306+
307+
/// Apply left-padding by 1 to the tile shape if required.
308+
static SmallVector<int64_t> padTileShapeToSize(ArrayRef<int64_t> tileShape,
309+
unsigned paddedSize) {
310+
assert(tileShape.size() <= paddedSize &&
311+
"expected tileShape to <= paddedSize");
312+
if (tileShape.size() == paddedSize)
313+
return to_vector(tileShape);
314+
SmallVector<int64_t> result(paddedSize - tileShape.size(), 1);
315+
llvm::append_range(result, tileShape);
316+
return result;
317+
}
318+
319+
mlir::detail::TileOffsetRangeImpl::TileOffsetRangeImpl(
    ArrayRef<int64_t> shape, ArrayRef<int64_t> tileShape,
    ArrayRef<int64_t> loopOrder)
    : tileShape(padTileShapeToSize(tileShape, shape.size())),
      inverseLoopOrder(invertPermutationVector(loopOrder)),
      sliceStrides(shape.size()) {
  // Divide the shape by the tile shape to get the number of tiles along each
  // dimension. Note: `tileShape` here names the constructor parameter (the
  // un-padded shape), not the member of the same name.
  std::optional<SmallVector<int64_t>> ratio =
      mlir::computeShapeRatio(shape, tileShape);
  assert(ratio && ratio->size() == shape.size() &&
         "target shape does not evenly divide the original shape");
  assert(isPermutationVector(loopOrder) && loopOrder.size() == shape.size() &&
         "expected loop order to be a permutation of rank equal to outer "
         "shape");

  // The end (exclusive) of the linear index range is the total tile count.
  maxLinearIndex = mlir::computeMaxLinearIndex(*ratio);
  // Permute the per-dimension tile counts by `loopOrder` before computing
  // strides, so linear indices enumerate tiles in the requested loop-nest
  // order.
  mlir::applyPermutationToVector(*ratio, loopOrder);
  sliceStrides = mlir::computeStrides(*ratio);
}
338+
339+
SmallVector<int64_t> mlir::detail::TileOffsetRangeImpl::getStaticTileOffsets(
    int64_t linearIndex) const {
  // Recover loop-order coordinates from the linear index, then undo the loop
  // order permutation to get per-dimension tile coordinates.
  SmallVector<int64_t> permutedCoords = delinearize(linearIndex, sliceStrides);
  SmallVector<int64_t> tileCoords =
      applyPermutation(permutedCoords, inverseLoopOrder);
  // Scale tile coordinates by the tile shape to obtain element offsets.
  return computeElementwiseMul(tileCoords, tileShape);
}
345+
346+
SmallVector<AffineExpr>
mlir::detail::TileOffsetRangeImpl::getDynamicTileOffsets(
    AffineExpr linearIndex) const {
  MLIRContext *ctx = linearIndex.getContext();
  // Same computation as getStaticTileOffsets, carried out symbolically in
  // terms of `linearIndex`.
  SmallVector<AffineExpr> permutedCoords =
      delinearize(linearIndex, sliceStrides);
  SmallVector<AffineExpr> tileCoords =
      applyPermutation(permutedCoords, inverseLoopOrder);
  return mlir::computeElementwiseMul(tileCoords,
                                     getAffineConstantExprs(tileShape, ctx));
}

0 commit comments

Comments
 (0)