#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/LogicalResult.h"

#include <algorithm>
#include <numeric>
#include <optional>
@@ -28,67 +32,329 @@ mlir::getReassociationIndicesForReshape(ShapedType sourceType,
28
32
return std::nullopt;
29
33
}
30
34
31
- std::optional<SmallVector<ReassociationIndices>>
32
- mlir::getReassociationIndicesForCollapse (ArrayRef<int64_t > sourceShape,
33
- ArrayRef<int64_t > targetShape) {
34
- if (sourceShape.size () <= targetShape.size ())
35
- return std::nullopt;
36
- unsigned sourceDim = 0 ;
37
- SmallVector<ReassociationIndices> reassociationMap;
38
- reassociationMap.reserve (targetShape.size ());
35
+ namespace {
36
+ // / A simple struct to represent ReassociationIndices as an inclusive interval.
37
+ // / It's designed to be feasibly minimal, so the call sites should manage the
38
+ // / validity of the range manually.
39
+ struct ReassociationIndexRange {
40
+ // / FIXME: Signed type is used for consistency with ReassociationIndices.
41
+ // / We should consider refactoring all reassociation utilities to use unsigned
42
+ // / types.
43
+ int64_t leftIdx = 0 , rightIdx = 0 ;
44
+
45
+ // / Util for manual checks of the range's validity
46
+ LogicalResult verify () const {
47
+ return leftIdx >= 0 && (leftIdx <= rightIdx) ? success () : failure ();
48
+ }
49
+
50
+ // / Checks range's containment within another range. Treats the edges
51
+ // / non-exclusively.
52
+ bool isInRange (const ReassociationIndexRange &outerRange) const {
53
+ return leftIdx >= outerRange.leftIdx && rightIdx <= outerRange.rightIdx ;
54
+ }
55
+
56
+ unsigned size () const {
57
+ assert (succeeded (verify ()));
58
+ return rightIdx - leftIdx + 1 ;
59
+ }
60
+ bool containsSingleIndex () const { return size () == 1 ; }
61
+
62
+ // / Collects indices that do not overlap between this and another range.
63
+ ReassociationIndices
64
+ getNonOverlappingIndicesWith (ReassociationIndexRange &rhs) const {
65
+ if (rightIdx < rhs.leftIdx ) {
66
+ // The intervals do not overlap - concatenate the indices from both.
67
+ auto jointFullIndices = getFullIndices ();
68
+ jointFullIndices.append (rhs.getFullIndices ());
69
+ return jointFullIndices;
70
+ }
71
+ ReassociationIndices result;
72
+ // Handle the chunk left of the overlapping range.
73
+ int64_t leftStart = std::min (leftIdx, rhs.leftIdx );
74
+ int64_t leftEnd = std::max (leftIdx, rhs.leftIdx );
75
+ llvm::append_range (result, llvm::seq (leftStart, leftEnd));
76
+ // Handle the chunk right of the overlapping range. Symmetrically, we should
77
+ // skip the edge of the overlap AND include the rightmost index.
78
+ int64_t rightStart = std::min (rightIdx, rhs.rightIdx ) + 1 ;
79
+ int64_t rightEnd = std::max (rightIdx, rhs.rightIdx );
80
+ if (rightStart < rightEnd)
81
+ llvm::append_range (result, llvm::seq_inclusive (rightStart, rightEnd));
82
+ return result;
83
+ }
84
+
85
+ // / Converts the range into ReassociationIndices.
86
+ ReassociationIndices getFullIndices () const {
87
+ ReassociationIndices result;
88
+ for (int64_t idx = leftIdx; idx <= rightIdx; ++idx) {
89
+ result.push_back (idx);
90
+ }
91
+ return result;
92
+ }
93
+ };
94
+ } // namespace
95
+
96
+ // / Starting from `sourceStartIdx`, searches `sourceShape` for the first
97
+ // / sequence that can be collapsed into a dynamic dimension (at least one must
98
+ // / be present in the source).
99
+ // / By default, lazily returns once the first dynamic dimension has been found.
100
+ // / Setting `matchGreedily` as `true` will also mark all subsequent
101
+ // / source dimensions for collapsing into the target.
102
+ static FailureOr<ReassociationIndexRange>
103
+ findReassociationRangeForDynamicDim (ArrayRef<int64_t > sourceShape,
104
+ int64_t sourceStartIdx,
105
+ bool matchGreedily = false ) {
106
+ const unsigned numSourceDims = sourceShape.size ();
107
+ ReassociationIndexRange sourceShapeAsRange{0 , numSourceDims - 1 };
108
+ std::optional<ReassociationIndexRange> resultRange = std::nullopt;
109
+
110
+ ReassociationIndexRange iterationRange{sourceStartIdx, sourceStartIdx};
111
+ for (; iterationRange.isInRange (sourceShapeAsRange);
112
+ iterationRange.rightIdx ++) {
113
+ int64_t sourceSize = sourceShape[iterationRange.rightIdx ];
114
+ if (sourceSize == ShapedType::kDynamic ) {
115
+ resultRange = iterationRange;
116
+ break ;
117
+ }
118
+ }
119
+ if (!resultRange)
120
+ return failure ();
121
+ if (matchGreedily)
122
+ resultRange->rightIdx = sourceShapeAsRange.rightIdx ;
123
+ return *resultRange;
124
+ }
39
125
40
- ReassociationIndices currIndices;
126
+ // / Starting from `sourceStartIdx`, searches `sourceShape` for the first
127
+ // / sequence of static dimensions such that their product matches `targetSize`.
128
+ // / By default, lazily returns once the product matches the target size. Setting
129
+ // / `matchGreedily` as `true` will append all neighboring unit dimensions
130
+ // / (dimensions of 1) to the match.
131
+ static FailureOr<ReassociationIndexRange>
132
+ findReassociationRangeForSize (ArrayRef<int64_t > sourceShape,
133
+ int64_t sourceStartIdx, int64_t targetSize,
134
+ bool matchGreedily = false ) {
135
+ const unsigned numSourceDims = sourceShape.size ();
136
+ ReassociationIndexRange sourceShapeAsRange{0 , numSourceDims - 1 };
137
+ std::optional<ReassociationIndexRange> resultRange = std::nullopt;
138
+
139
+ ReassociationIndexRange iterationRange{sourceStartIdx, sourceStartIdx};
41
140
int64_t prodOfCollapsedDims = 1 ;
42
- while (sourceDim < sourceShape.size ()) {
43
- unsigned targetDim = reassociationMap.size ();
44
- // If we have mapped all the target dimensions stop and handle the remaining
45
- // tail of size-1 dimensions explicitly.
46
- if (targetDim == targetShape.size ())
141
+ while (iterationRange.isInRange (sourceShapeAsRange)) {
142
+ int64_t sourceSize = sourceShape[iterationRange.rightIdx ];
143
+ if (sourceSize == ShapedType::kDynamic ) {
144
+ // Reassociation for a static dim cannot include a dynamic dim. Reset
145
+ // induction variables to essentially restart the loop from the next
146
+ // source dimension.
147
+ prodOfCollapsedDims = 1 ;
148
+ iterationRange = {iterationRange.rightIdx + 1 ,
149
+ iterationRange.rightIdx + 1 };
150
+ continue ;
151
+ }
152
+ prodOfCollapsedDims *= sourceSize;
153
+ // If the target size has been exceeded without matching, we need to shift
154
+ // the range start right. From the start of the range, roll back the
155
+ // multiplication until the target size exceeds the product again.
156
+ while (prodOfCollapsedDims > targetSize &&
157
+ !iterationRange.containsSingleIndex ()) {
158
+ int64_t frontSourceSize = sourceShape[iterationRange.leftIdx ];
159
+ prodOfCollapsedDims /= frontSourceSize;
160
+ // Shrink the range rightwards
161
+ iterationRange.leftIdx ++;
162
+ }
163
+ // We could've reached the target size with the current dimension,
164
+ // also as a result of the above shift to right.
165
+ if (prodOfCollapsedDims == targetSize) {
166
+ resultRange = iterationRange;
47
167
break ;
168
+ }
169
+ // Increment the iteration range
170
+ iterationRange.rightIdx ++;
171
+ }
172
+ if (!resultRange)
173
+ return failure ();
174
+ if (matchGreedily) {
175
+ // We now want to collect all unit dimensions directly after the target
176
+ // product match. Advance the iterator to avoid OOB when the product match
177
+ // happens at the last element.
178
+ iterationRange.rightIdx ++;
179
+ while (iterationRange.isInRange (sourceShapeAsRange) &&
180
+ sourceShape[iterationRange.rightIdx ] == 1 ) {
181
+ resultRange = iterationRange;
182
+ iterationRange.rightIdx ++;
183
+ }
184
+ }
185
+ return *resultRange;
186
+ }
48
187
49
- int64_t currTargetShape = targetShape[targetDim];
50
- while (sourceDim < (sourceShape.size () - 1 ) &&
51
- sourceShape[sourceDim] != ShapedType::kDynamic &&
52
- prodOfCollapsedDims * sourceShape[sourceDim] < currTargetShape) {
53
- prodOfCollapsedDims *= sourceShape[sourceDim];
54
- currIndices.push_back (sourceDim++);
188
+ // / Attempts to find a valid collapsing reassociation of `sourceShape` into
189
+ // / `targetShape` through a simple traversal. If successful, an array of source
190
+ // / index ranges is returned, correspondingly to each dimension in the target
191
+ // / shape. The resulting indices shall fully cover the `sourceShape` without
192
+ // / overlaps.
193
+ // /
194
+ // / The algorithm is essentially a lazy one, searching for non-greedy matches -
195
+ // / it will only yield a greedy match for the last target dimension.
196
+ // / FIXME: The algorithm can only backtrack when it needs to append an offset
197
+ // / for a static target dimension to the preceding dynamic one (this retains the
198
+ // / linear complexity). As feasible, consider adding further backtracking
199
+ // / routines to enable more reassociations, e.g.:
200
+ // / - ?x2x?x2 into ?x2
201
+ static FailureOr<SmallVector<ReassociationIndexRange>>
202
+ findReassociationRangesForCollapse (ArrayRef<int64_t > sourceShape,
203
+ ArrayRef<int64_t > targetShape) {
204
+ unsigned numSourceDims = sourceShape.size (),
205
+ numTargetDims = targetShape.size ();
206
+ assert (numSourceDims > numTargetDims);
207
+ ReassociationIndexRange sourceShapeAsRange{0 , numSourceDims - 1 };
208
+
209
+ SmallVector<ReassociationIndexRange> reassocRanges;
210
+ reassocRanges.reserve (numTargetDims);
211
+ // We'll iterate in strides of 2 to enable pseudo-backtracking for simple
212
+ // cases, e.g.:
213
+ // - ?x2x3x5 into ?x15
214
+ std::optional<int64_t > prevTargetSize = std::nullopt;
215
+ for (unsigned targetDimIdx = 0 , sourceDimIdx = 0 ;
216
+ targetDimIdx < numTargetDims; ++targetDimIdx) {
217
+ int64_t targetSize = targetShape[targetDimIdx];
218
+ // Simply check if there are any subsequent target dimensions left - if not,
219
+ // the match must be made greedily.
220
+ bool shouldMatchGreedily = targetDimIdx == numTargetDims - 1 ;
221
+ FailureOr<ReassociationIndexRange> sourceRange;
222
+ if (targetSize == ShapedType::kDynamic ) {
223
+ sourceRange = findReassociationRangeForDynamicDim (
224
+ sourceShape, sourceDimIdx, shouldMatchGreedily);
225
+ } else {
226
+ sourceRange = findReassociationRangeForSize (
227
+ sourceShape, sourceDimIdx, targetSize, shouldMatchGreedily);
55
228
}
56
229
57
- // If the current expanded dimension is dynamic, then the collapsed
58
- // dimensions should also be dynamic and product of all previous unprocessed
59
- // dimensions of the expanded shape should be 1.
60
- if (sourceShape[sourceDim] == ShapedType::kDynamic &&
61
- (currTargetShape != ShapedType::kDynamic || prodOfCollapsedDims != 1 ))
62
- return std::nullopt;
63
-
64
- // If the collapsed dim is dynamic, the current expanded dim should also
65
- // be dynamic.
66
- if (currTargetShape == ShapedType::kDynamic &&
67
- sourceShape[sourceDim] != ShapedType::kDynamic )
68
- return std::nullopt;
69
-
70
- // For static shapes, if the product of dimensions of the expanded shape
71
- // should match the collapsed dimension shape.
72
- if (prodOfCollapsedDims * sourceShape[sourceDim] != currTargetShape)
73
- return std::nullopt;
74
-
75
- currIndices.push_back (sourceDim++);
76
- reassociationMap.emplace_back (ReassociationIndices{});
77
- std::swap (reassociationMap.back (), currIndices);
78
- prodOfCollapsedDims = 1 ;
230
+ // Run sanity checks on the returned index range.
231
+ if (failed (sourceRange) || failed (sourceRange->verify ()) ||
232
+ !sourceRange->isInRange (sourceShapeAsRange))
233
+ return failure ();
234
+ if (sourceRange->leftIdx > sourceDimIdx) {
235
+ // If some source dimensions had to be skipped in order to find a match,
236
+ // they must be collapsed into the directly preceding dynamic dimension.
237
+ if (!prevTargetSize || prevTargetSize != ShapedType::kDynamic )
238
+ return failure ();
239
+ reassocRanges.back ().rightIdx = sourceRange->leftIdx - 1 ;
240
+ }
241
+
242
+ // Store the gathered information as required for the next iteration.
243
+ prevTargetSize = targetSize;
244
+ sourceDimIdx = sourceRange->rightIdx + 1 ;
245
+ reassocRanges.push_back (*sourceRange);
79
246
}
80
- // All the dimensions in the target must have been processed.
81
- if (reassociationMap.size () != targetShape.size ())
247
+ // Fail if the source shape wasn't a full match for the target shape. We only
248
+ // need to check the last recorded index - any other gaps should have been
249
+ // mended by the main loop.
250
+ if (reassocRanges.back ().rightIdx < sourceShapeAsRange.rightIdx )
251
+ return failure ();
252
+ return reassocRanges;
253
+ }
254
+
255
+ // / A variant of `findReassociationRangesForCollapse(...)` that can also scan
256
+ // / the shapes right-to-left.
257
+ static FailureOr<SmallVector<ReassociationIndexRange>>
258
+ findReassociationRangesForCollapse (ArrayRef<int64_t > sourceShape,
259
+ ArrayRef<int64_t > targetShape,
260
+ bool iterateRightToLeft) {
261
+ if (!iterateRightToLeft)
262
+ return findReassociationRangesForCollapse (sourceShape, targetShape);
263
+ // NB: To iterate right-to-left, we currently reverse the shapes and then
264
+ // reverse the result back. The reversed shapes must not be temporary, as
265
+ // we're passing through an ArrayRef.
266
+ // FIXME: It would be preferable to avoid the expensive copies. At the moment,
267
+ // this approach is chosen for readability of the main implementation.
268
+ std::vector<int64_t > sourceToReverse = sourceShape.vec (),
269
+ targetToReverse = targetShape.vec ();
270
+ std::reverse (sourceToReverse.begin (), sourceToReverse.end ());
271
+ std::reverse (targetToReverse.begin (), targetToReverse.end ());
272
+ auto invertedRanges =
273
+ findReassociationRangesForCollapse (sourceToReverse, targetToReverse);
274
+ if (failed (invertedRanges))
275
+ return failure ();
276
+ SmallVector<ReassociationIndexRange> &rangesToInvert = *invertedRanges;
277
+ unsigned numSourceDims = sourceShape.size ();
278
+ // We have received the ranges for inverted shapes. Now we have to invert
279
+ // the ranges back to correspond with the original source shape.
280
+ for (auto &range : rangesToInvert) {
281
+ int64_t invLeftIdx = range.leftIdx , invRightIdx = range.rightIdx ;
282
+ range.leftIdx = numSourceDims - 1 - invRightIdx;
283
+ range.rightIdx = numSourceDims - 1 - invLeftIdx;
284
+ }
285
+ // Also invert the ordering of the ranges to correspond with the original
286
+ // target shape.
287
+ std::reverse (rangesToInvert.begin (), rangesToInvert.end ());
288
+ return rangesToInvert;
289
+ }
290
+
291
+ std::optional<SmallVector<ReassociationIndices>>
292
+ mlir::getReassociationIndicesForCollapse (ArrayRef<int64_t > sourceShape,
293
+ ArrayRef<int64_t > targetShape) {
294
+ unsigned numSourceDims = sourceShape.size (),
295
+ numTargetDims = targetShape.size ();
296
+ // We're supposed to search for a collapsing reassociation. If the sizes
297
+ // match, there's no actual collapsing taking place - it's either a no-op or a
298
+ // `tensor.reshape`-style reassociation (that would be beyond the scope of
299
+ // this utility).
300
+ if (numSourceDims <= numTargetDims)
301
+ return std::nullopt;
302
+ // Early handling for scalar target types.
303
+ if (numTargetDims == 0 ) {
304
+ ReassociationIndices allSourceIndices;
305
+ allSourceIndices.reserve (numSourceDims);
306
+ for (unsigned sourceDimIdx = 0 ; sourceDimIdx < numSourceDims;
307
+ ++sourceDimIdx) {
308
+ int64_t sourceSize = sourceShape[sourceDimIdx];
309
+ // All source dimensions must be unit or dynamic.
310
+ if (sourceSize != 1 && sourceSize != ShapedType::kDynamic )
311
+ return std::nullopt;
312
+ allSourceIndices.push_back (sourceDimIdx);
313
+ }
314
+ return SmallVector<ReassociationIndices>{allSourceIndices};
315
+ }
316
+
317
+ // Collect source ranges by iterating over the target shape left-to-right.
318
+ FailureOr<SmallVector<ReassociationIndexRange>> maybeForwardRanges =
319
+ findReassociationRangesForCollapse (sourceShape, targetShape);
320
+ if (failed (maybeForwardRanges))
321
+ return std::nullopt;
322
+ auto &ranges = *maybeForwardRanges;
323
+ // Now do the same in reverse. We need to get another valid reassociation
324
+ // through some other strategy, and then compare the results in order to
325
+ // disambiguate mixed subshapes, such as:
326
+ // ?x?x? into ?x?, ?x2x? into ?x?, ?x2x3x6x? into ?x6x?
327
+ // This leads us to lose some of the reassociation opportunities that can only
328
+ // be found by iterating in a certain direction, e.g. 2x2x? into 2x? - without
329
+ // backtracking, the algorithm will fail right-to-left. However, this is the
330
+ // best way to preserve correctness.
331
+ FailureOr<SmallVector<ReassociationIndexRange>> maybeReverseRanges =
332
+ findReassociationRangesForCollapse (sourceShape, targetShape,
333
+ /* iterateRightToLeft=*/ true );
334
+ if (failed (maybeReverseRanges))
335
+ return std::nullopt;
336
+ auto &reverseRanges = *maybeReverseRanges;
337
+
338
+ if (ranges.size () != numTargetDims || reverseRanges.size () != numTargetDims)
82
339
return std::nullopt;
83
- // Process any remaining entries in the source shape. They all need to be
84
- // 1 or dynamic.
85
- for (; sourceDim < sourceShape.size (); sourceDim++) {
86
- if (sourceShape[sourceDim] != ShapedType::kDynamic &&
87
- sourceShape[sourceDim] != 1 )
88
- return std::nullopt;
89
- // The map is empty when the target type is a scalar.
90
- if (!reassociationMap.empty ())
91
- reassociationMap.back ().push_back (sourceDim);
340
+ // Now we can check for ambiguity of each target dimension's reassociation. If
341
+ // successful, we put the full indices into our result map for the target
342
+ // shape.
343
+ SmallVector<ReassociationIndices> reassociationMap (numTargetDims);
344
+ for (unsigned targetDimIdx = 0 ; targetDimIdx < numTargetDims;
345
+ ++targetDimIdx) {
346
+ ReassociationIndexRange &range = ranges[targetDimIdx];
347
+ ReassociationIndexRange &reverseRange = reverseRanges[targetDimIdx];
348
+ // Get non-overlapping indices between the ranges
349
+ ReassociationIndices nonMatchingIndices =
350
+ range.getNonOverlappingIndicesWith (reverseRange);
351
+ // Unit dimensions can be collapsed wherever - this is the only ambiguity
352
+ // that we allow.
353
+ for (int64_t sourceDimIdx : nonMatchingIndices) {
354
+ if (sourceShape[sourceDimIdx] != 1 )
355
+ return std::nullopt;
356
+ }
357
+ reassociationMap[targetDimIdx] = range.getFullIndices ();
92
358
}
93
359
return reassociationMap;
94
360
}
0 commit comments