18
18
#include " mlir/Support/LLVM.h"
19
19
#include " llvm/ADT/ArrayRef.h"
20
20
#include " llvm/ADT/SmallVector.h"
21
+ #include " llvm/ADT/iterator.h"
21
22
#include < optional>
23
+ #include < utility>
22
24
23
25
namespace mlir {
24
26
class ArrayAttr ;
@@ -195,6 +197,23 @@ SmallVector<AffineExpr> delinearize(AffineExpr linearIndex,
195
197
// Permutation utils.
196
198
// ===----------------------------------------------------------------------===//
197
199
200
+ template <typename T>
201
+ SmallVector<T> applyPermutation (ArrayRef<T> input,
202
+ ArrayRef<int64_t > permutation) {
203
+ assert (input.size () == permutation.size () &&
204
+ " expected input rank to equal permutation rank" );
205
+ auto permutationRange = llvm::map_range (
206
+ llvm::seq<unsigned >(0 , input.size ()),
207
+ [&](int64_t idx) -> T { return input[permutation[idx]]; });
208
+ return llvm::to_vector (permutationRange);
209
+ }
210
+
211
+ template <typename T>
212
+ SmallVector<T> applyPermutation (const SmallVectorImpl<T> &input,
213
+ ArrayRef<int64_t > permutation) {
214
+ return applyPermutation (ArrayRef (input), permutation);
215
+ }
216
+
198
217
// / Apply the permutation defined by `permutation` to `inVec`.
199
218
// / Element `i` in `inVec` is mapped to location `j = permutation[i]`.
200
219
// / E.g.: for an input vector `inVec = ['a', 'b', 'c']` and a permutation
@@ -203,10 +222,7 @@ SmallVector<AffineExpr> delinearize(AffineExpr linearIndex,
203
222
template <typename T, unsigned N>
204
223
void applyPermutationToVector (SmallVector<T, N> &inVec,
205
224
ArrayRef<int64_t > permutation) {
206
- SmallVector<T, N> auxVec (inVec.size ());
207
- for (const auto &en : enumerate(permutation))
208
- auxVec[en.index ()] = inVec[en.value ()];
209
- inVec = auxVec;
225
+ inVec = applyPermutation (inVec, permutation);
210
226
}
211
227
212
228
// / Helper method to apply to inverse a permutation.
@@ -239,6 +255,138 @@ std::pair<AffineExpr, SmallVector<OpFoldResult>>
239
255
computeLinearIndex (OpFoldResult sourceOffset, ArrayRef<OpFoldResult> strides,
240
256
ArrayRef<OpFoldResult> indices);
241
257
258
+ // ===----------------------------------------------------------------------===//
259
+ // Utilities for decomposing larger shapes
260
+ // ===----------------------------------------------------------------------===//
261
+
262
+ namespace detail {
263
+ // / Encapsulates the set of parameters that are used to make tile offset
264
+ // / calculations in the TileOffsetRangeIterator.
265
+ class TileOffsetRangeImpl {
266
+ public:
267
+ TileOffsetRangeImpl (ArrayRef<int64_t > shape, ArrayRef<int64_t > tileShape,
268
+ ArrayRef<int64_t > loopOrder);
269
+
270
+ int64_t getMaxLinearIndex () const { return maxLinearIndex; }
271
+
272
+ SmallVector<int64_t > getStaticTileOffsets (int64_t linearIndex) const ;
273
+
274
+ SmallVector<AffineExpr> getDynamicTileOffsets (AffineExpr linearIndex) const ;
275
+
276
+ template <typename T>
277
+ SmallVector<T> getTileOffsets (T linearIndex) const {
278
+ if constexpr (std::is_same_v<T, int64_t >)
279
+ return getStaticTileOffsets (linearIndex);
280
+ else
281
+ return getDynamicTileOffsets (linearIndex);
282
+ }
283
+
284
+ private:
285
+ // / The sub-shape that divides the larger outer shape (which is provided to
286
+ // / the constructor).
287
+ SmallVector<int64_t > tileShape;
288
+ // / The inverse permutation to the `loopOrder` permutation provided in the
289
+ // / constructor.
290
+ SmallVector<int64_t > inverseLoopOrder;
291
+ // / The strides for the basis 'div(shape, tileShape)' permuted by `loopOrder`.
292
+ SmallVector<int64_t > sliceStrides;
293
+ // / The maximum linear index in the iteration space given by basis 'div(shape,
294
+ // / tileShape)'.
295
+ int64_t maxLinearIndex;
296
+ };
297
+
298
+ // / The STL-style iterator implementation for StaticTileOffsetRange.
299
+ template <typename ElementType>
300
+ class TileOffsetRangeIterator
301
+ : public llvm::iterator_facade_base<TileOffsetRangeIterator<ElementType>,
302
+ std::forward_iterator_tag,
303
+ SmallVector<ElementType>> {
304
+ public:
305
+ TileOffsetRangeIterator (const TileOffsetRangeImpl ¶ms, ElementType index)
306
+ : params(params), index(index) {}
307
+
308
+ void operator ++() { incrementIndex (1 ); }
309
+ TileOffsetRangeIterator operator ++(int ) {
310
+ const auto copy = *this ;
311
+ ++*this ;
312
+ return copy;
313
+ }
314
+
315
+ bool operator ==(const TileOffsetRangeIterator &other) const {
316
+ return index == other.index ;
317
+ }
318
+ bool operator !=(const TileOffsetRangeIterator &other) const {
319
+ return index != other.index ;
320
+ }
321
+
322
+ SmallVector<ElementType> operator *() const {
323
+ return params.getTileOffsets (index);
324
+ }
325
+ void operator +=(int64_t offset) { incrementIndex (offset); }
326
+
327
+ private:
328
+ void incrementIndex (int64_t offset) { index = index + offset; }
329
+ const TileOffsetRangeImpl params;
330
+ int64_t index;
331
+ };
332
+ } // namespace detail
333
+
334
+ // / A range-style iterator that allows for iterating over the offsets of all
335
+ // / potential tiles of size `tileShape` within the larger shape `shape`, using
336
+ // / an ordering specified by `loopOrder`. The `loopOrder` specifies the order of
337
+ // / unrolling by numbering the dimensions in order from "outer most for loop"
338
+ // / (slowest changing) to "inner most for loop" (fastest changing).
339
+ // /
340
+ // / For example, for `shape = {10, 20, 30}`, `tileShape = {5, 10, 15}`, and
341
+ // / `loopOrder={2, 0, 1}`, the iterating over this range will yield offsets:
342
+ // /
343
+ // / ```
344
+ // / {0, 0, 0}, {0, 10, 0}, {5, 0, 0}, {5, 10, 0}, {0, 0, 15},
345
+ // / {0, 10, 15}, {5, 0, 15}, {0, 10, 15}, {5, 10, 15}
346
+ // / ```
347
+ // /
348
+ // / This is useful in contexts where a vector computation over a larger shape
349
+ // / needs to be unrolled to a set of operations on subsets of the original
350
+ // / operands, such as during the "vector unrolling" transformations.
351
+ // /
352
+ // / The size of `tileShape` must be less-than-or-equal-to the size of `shape`.a
353
+ // / If the rank of `tileShape` is smaller than `shape`, then `tileShape`
354
+ // / elements correspond to the trailing dimensions of `shape`, and the leading
355
+ // / dimensions are considered untiled and `tileShape` is effectively prepended
356
+ // / with the leading dims of `shape`.
357
+ class StaticTileOffsetRange {
358
+ public:
359
+ using IteratorTy = detail::TileOffsetRangeIterator<int64_t >;
360
+ using ParamsTy = detail::TileOffsetRangeImpl;
361
+
362
+ StaticTileOffsetRange (ArrayRef<int64_t > shape, ArrayRef<int64_t > tileShape,
363
+ ArrayRef<int64_t > loopOrder)
364
+ : params(shape, tileShape, loopOrder), beginValue(params, 0 ),
365
+ pastEndValue (params, params.getMaxLinearIndex()) {
366
+ assert (shape.size () >= tileShape.size ());
367
+ assert (loopOrder.size () == shape.size ());
368
+ }
369
+
370
+ // / Create the range with identity loop order.
371
+ StaticTileOffsetRange (ArrayRef<int64_t > shape, ArrayRef<int64_t > tileShape)
372
+ : params(shape, tileShape,
373
+ llvm::to_vector (llvm::seq<int64_t >(0 , shape.size()))),
374
+ beginValue(params, 0 ),
375
+ pastEndValue(params, params.getMaxLinearIndex()) {
376
+ assert (shape.size () >= tileShape.size ());
377
+ }
378
+
379
+ IteratorTy begin () const { return beginValue; }
380
+ IteratorTy end () const { return pastEndValue; }
381
+
382
+ // / Returns the total number of tiles that fit in the larger shape.
383
+ size_t size () const { return params.getMaxLinearIndex (); }
384
+
385
+ private:
386
+ const ParamsTy params;
387
+ IteratorTy beginValue;
388
+ IteratorTy pastEndValue;
389
+ };
242
390
} // namespace mlir
243
391
244
392
#endif // MLIR_DIALECT_UTILS_INDEXINGUTILS_H
0 commit comments