15
15
#include " mlir/Dialect/Vector/Transforms/VectorTransforms.h"
16
16
#include " mlir/IR/ImplicitLocOpBuilder.h"
17
17
#include " mlir/Interfaces/VectorInterfaces.h"
18
+ #include " mlir/Support/MathExtras.h"
18
19
#include " llvm/ADT/MapVector.h"
19
- #include " llvm/Support/Debug.h"
20
+ #include " llvm/ADT/STLExtras.h"
21
+ #include < numeric>
20
22
21
23
#define DEBUG_TYPE " vector-unrolling"
22
24
@@ -36,20 +38,81 @@ static SmallVector<int64_t, 4> getVectorOffset(ArrayRef<int64_t> originalShape,
36
38
return elementOffsets;
37
39
}
38
40
41
+ // / A functor that accomplishes the same thing as `getVectorOffset` but allows
42
+ // / for reordering the traversal of the dimensions. The order of traversal is
43
+ // / given in "for loop order" (outer to inner).
44
+ namespace {
45
+ class DecomposeShapeIterator {
46
+ private:
47
+ SmallVector<int64_t , 4 > vectorShape;
48
+ SmallVector<int64_t > loopOrder;
49
+ SmallVector<int64_t > sliceStrides;
50
+ int64_t maxIndexVal{1 };
51
+
52
+ public:
53
+ DecomposeShapeIterator (ArrayRef<int64_t > originalShape,
54
+ ArrayRef<int64_t > targetShape,
55
+ ArrayRef<int64_t > loopOrder)
56
+ : vectorShape(targetShape.begin(), targetShape.end()),
57
+ loopOrder (loopOrder.begin(), loopOrder.end()),
58
+ sliceStrides(originalShape.size()) {
59
+ assert (originalShape.size () == targetShape.size ());
60
+ assert (loopOrder.size () == targetShape.size ());
61
+
62
+ // Compute the count for each dimension.
63
+ SmallVector<int64_t > sliceDimCounts (originalShape.size ());
64
+ for (unsigned r = 0 ; r < originalShape.size (); ++r) {
65
+ sliceDimCounts[r] = ceilDiv (originalShape[r], targetShape[r]);
66
+ maxIndexVal *= sliceDimCounts[r];
67
+ }
68
+
69
+ // Reversing "loop order" gives dimensions from fastest varying to slowest
70
+ // varying (smallest stride to largest stride).
71
+ int64_t accum = 1 ;
72
+ for (auto idx : llvm::reverse (loopOrder)) {
73
+ sliceStrides[idx] = accum;
74
+ accum *= sliceDimCounts[idx];
75
+ }
76
+ }
77
+
78
+ // Turn the linear index into a d-tuple based on units of vectors of size
79
+ // `vectorShape`. The linear index is assumed to represent traversal of the
80
+ // dimensions based on `order`.
81
+ SmallVector<int64_t > delinearize (int64_t index) const {
82
+ // Traverse in for loop order (largest stride to smallest stride).
83
+ SmallVector<int64_t > vectorOffsets (sliceStrides.size ());
84
+ for (auto idx : loopOrder) {
85
+ vectorOffsets[idx] = index / sliceStrides[idx];
86
+ index %= sliceStrides[idx];
87
+ }
88
+ return vectorOffsets;
89
+ }
90
+
91
+ int64_t maxIndex () const { return maxIndexVal; }
92
+
93
+ // / Return the offset within d-tuple based on the ordering given by
94
+ // / `loopOrder`.
95
+ SmallVector<int64_t > getVectorOffset (int64_t index) const {
96
+ SmallVector<int64_t > vectorOffsets = delinearize (index);
97
+ SmallVector<int64_t > elementOffsets =
98
+ computeElementOffsetsFromVectorSliceOffsets (vectorShape, vectorOffsets);
99
+ return elementOffsets;
100
+ }
101
+ };
102
+ } // namespace
103
+
39
104
/// Compute the indices of the slice `index` for a transfer op.
40
- static SmallVector<Value>
41
- sliceTransferIndices ( int64_t index, ArrayRef<int64_t > originalShape ,
42
- ArrayRef< int64_t > targetShape, ArrayRef<Value> indices ,
43
- AffineMap permutationMap, Location loc,
44
- OpBuilder &builder) {
105
+ static SmallVector<Value> sliceTransferIndices (ArrayRef< int64_t > elementOffsets,
106
+ ArrayRef<Value> indices ,
107
+ AffineMap permutationMap ,
108
+ Location loc,
109
+ OpBuilder &builder) {
45
110
MLIRContext *ctx = builder.getContext ();
46
111
auto isBroadcast = [](AffineExpr expr) {
47
112
if (auto constExpr = expr.dyn_cast <AffineConstantExpr>())
48
113
return constExpr.getValue () == 0 ;
49
114
return false ;
50
115
};
51
- SmallVector<int64_t , 4 > elementOffsets =
52
- getVectorOffset (originalShape, targetShape, index);
53
116
// Compute 'sliceIndices' by adding 'sliceOffsets[i]' to 'indices[i]'.
54
117
SmallVector<Value> slicedIndices (indices.begin (), indices.end ());
55
118
for (const auto &dim : llvm::enumerate (permutationMap.getResults ())) {
@@ -99,6 +162,20 @@ getTargetShape(const vector::UnrollVectorOptions &options, Operation *op) {
99
162
return targetShape;
100
163
}
101
164
165
+ static SmallVector<int64_t >
166
+ getUnrollOrder (unsigned numLoops, Operation *op,
167
+ const vector::UnrollVectorOptions &options) {
168
+ SmallVector<int64_t > loopOrder =
169
+ llvm::to_vector (llvm::seq<int64_t >(0 , static_cast <int64_t >(numLoops)));
170
+ if (options.traversalOrderCallback != nullptr ) {
171
+ Optional<SmallVector<int64_t >> order = options.traversalOrderCallback (op);
172
+ if (order.hasValue ()) {
173
+ loopOrder = std::move (*order);
174
+ }
175
+ }
176
+ return loopOrder;
177
+ }
178
+
102
179
namespace {
103
180
104
181
struct UnrollTransferReadPattern
@@ -121,27 +198,30 @@ struct UnrollTransferReadPattern
121
198
SmallVector<int64_t , 4 > strides (targetShape->size (), 1 );
122
199
Location loc = readOp.getLoc ();
123
200
ArrayRef<int64_t > originalSize = readOp.getVectorType ().getShape ();
124
- SmallVector<int64_t , 4 > ratio = *shapeRatio (originalSize, *targetShape);
125
- // Compute shape ratio of 'shape' and 'sizes'.
126
- int64_t sliceCount = computeMaxLinearIndex (ratio);
201
+
127
202
// Prepare the result vector;
128
203
Value result = rewriter.create <arith::ConstantOp>(
129
204
loc, sourceVectorType, rewriter.getZeroAttr (sourceVectorType));
130
205
auto targetType =
131
206
VectorType::get (*targetShape, sourceVectorType.getElementType ());
132
207
SmallVector<Value, 4 > originalIndices (readOp.getIndices ().begin (),
133
208
readOp.getIndices ().end ());
134
- for (int64_t i = 0 ; i < sliceCount; i++) {
209
+
210
+ SmallVector<int64_t > loopOrder =
211
+ getUnrollOrder (originalSize.size (), readOp, options);
212
+ DecomposeShapeIterator indexToOffsets (originalSize, *targetShape,
213
+ loopOrder);
214
+ for (int64_t i = 0 ; i < indexToOffsets.maxIndex (); i++) {
215
+ SmallVector<int64_t , 4 > elementOffsets =
216
+ indexToOffsets.getVectorOffset (i);
135
217
SmallVector<Value, 4 > indices =
136
- sliceTransferIndices (i, originalSize, *targetShape , originalIndices,
218
+ sliceTransferIndices (elementOffsets , originalIndices,
137
219
readOp.getPermutationMap (), loc, rewriter);
138
220
auto slicedRead = rewriter.create <vector::TransferReadOp>(
139
221
loc, targetType, readOp.getSource (), indices,
140
222
readOp.getPermutationMapAttr (), readOp.getPadding (), readOp.getMask (),
141
223
readOp.getInBoundsAttr ());
142
224
143
- SmallVector<int64_t , 4 > elementOffsets =
144
- getVectorOffset (originalSize, *targetShape, i);
145
225
result = rewriter.create <vector::InsertStridedSliceOp>(
146
226
loc, slicedRead, result, elementOffsets, strides);
147
227
}
@@ -174,20 +254,21 @@ struct UnrollTransferWritePattern
174
254
SmallVector<int64_t , 4 > strides (targetShape->size (), 1 );
175
255
Location loc = writeOp.getLoc ();
176
256
ArrayRef<int64_t > originalSize = sourceVectorType.getShape ();
177
- SmallVector<int64_t , 4 > ratio = *shapeRatio (originalSize, *targetShape);
178
- // Compute shape ratio of 'shape' and 'sizes'.
179
- int64_t sliceCount = computeMaxLinearIndex (ratio);
180
257
SmallVector<Value, 4 > originalIndices (writeOp.getIndices ().begin (),
181
258
writeOp.getIndices ().end ());
259
+
260
+ SmallVector<int64_t > loopOrder =
261
+ getUnrollOrder (originalSize.size (), writeOp, options);
262
+ DecomposeShapeIterator indexToOffsets (originalSize, *targetShape,
263
+ loopOrder);
182
264
Value resultTensor;
183
- for (int64_t i = 0 ; i < sliceCount ; i++) {
265
+ for (int64_t i = 0 ; i < indexToOffsets. maxIndex () ; i++) {
184
266
SmallVector<int64_t , 4 > elementOffsets =
185
- getVectorOffset (originalSize, *targetShape, i);
267
+ indexToOffsets. getVectorOffset (i);
186
268
Value slicedVector = rewriter.create <vector::ExtractStridedSliceOp>(
187
269
loc, writeOp.getVector (), elementOffsets, *targetShape, strides);
188
-
189
270
SmallVector<Value, 4 > indices =
190
- sliceTransferIndices (i, originalSize, *targetShape , originalIndices,
271
+ sliceTransferIndices (elementOffsets , originalIndices,
191
272
writeOp.getPermutationMap (), loc, rewriter);
192
273
Operation *slicedWrite = rewriter.create <vector::TransferWriteOp>(
193
274
loc, slicedVector, resultTensor ? resultTensor : writeOp.getSource (),
@@ -236,20 +317,22 @@ struct UnrollContractionPattern
236
317
return failure ();
237
318
auto dstVecType = contractOp.getResultType ().cast <VectorType>();
238
319
SmallVector<int64_t , 4 > originalSize = *contractOp.getShapeForUnroll ();
239
- SmallVector<int64_t , 4 > ratio = *shapeRatio (originalSize, *targetShape);
240
320
241
- // Compute shape ratio of 'shape' and 'sizes'.
242
- int64_t sliceCount = computeMaxLinearIndex (ratio);
243
321
Location loc = contractOp.getLoc ();
244
322
unsigned accIndex = vector::ContractionOp::getAccOperandIndex ();
245
323
AffineMap dstAffineMap = contractOp.getIndexingMaps ()[accIndex];
246
324
llvm::MapVector<
247
325
SmallVector<int64_t >, Value,
248
326
llvm::DenseMap<SmallVector<int64_t >, unsigned , OffsetMapInfo>>
249
327
accCache;
328
+
329
+ SmallVector<int64_t > loopOrder = getUnrollOrder (
330
+ contractOp.getIteratorTypes ().size (), contractOp, options);
331
+ DecomposeShapeIterator indexToOffsets (originalSize, *targetShape,
332
+ loopOrder);
333
+ const int64_t sliceCount = indexToOffsets.maxIndex ();
250
334
for (int64_t i = 0 ; i < sliceCount; i++) {
251
- SmallVector<int64_t , 4 > offsets =
252
- getVectorOffset (originalSize, *targetShape, i);
335
+ SmallVector<int64_t , 4 > offsets = indexToOffsets.getVectorOffset (i);
253
336
SmallVector<Value, 4 > slicesOperands (contractOp.getNumOperands ());
254
337
255
338
// Helper to compute the new shape of each operand and extract the slice.
0 commit comments