10
10
11
11
#include " mlir/IR/AffineMap.h"
12
12
#include " mlir/IR/Builders.h"
13
- #include " mlir/IR/BuiltinTypeInterfaces.h"
14
- #include " llvm/ADT/ArrayRef.h"
15
- #include " llvm/ADT/SmallVector.h"
16
- #include " llvm/Support/LogicalResult.h"
17
13
18
14
#include < numeric>
19
15
#include < optional>
@@ -32,329 +28,67 @@ mlir::getReassociationIndicesForReshape(ShapedType sourceType,
32
28
return std::nullopt;
33
29
}
34
30
35
- namespace {
36
- // / A simple struct to represent ReassociationIndices as an inclusive interval.
37
- // / It's designed to be feasibly minimal, so the call sites should manage the
38
- // / validity of the range manually.
39
- struct ReassociationIndexRange {
40
- // / FIXME: Signed type is used for consistency with ReassociationIndices.
41
- // / We should consider refactoring all reassociation utilities to use unsigned
42
- // / types.
43
- int64_t leftIdx = 0 , rightIdx = 0 ;
44
-
45
- // / Util for manual checks of the range's validity
46
- LogicalResult verify () const {
47
- return leftIdx >= 0 && (leftIdx <= rightIdx) ? success () : failure ();
48
- }
49
-
50
- // / Checks range's containment within another range. Treats the edges
51
- // / non-exclusively.
52
- bool isInRange (const ReassociationIndexRange &outerRange) const {
53
- return leftIdx >= outerRange.leftIdx && rightIdx <= outerRange.rightIdx ;
54
- }
55
-
56
- unsigned size () const {
57
- assert (succeeded (verify ()));
58
- return rightIdx - leftIdx + 1 ;
59
- }
60
- bool containsSingleIndex () const { return size () == 1 ; }
61
-
62
- // / Collects indices that do not overlap between this and another range.
63
- ReassociationIndices
64
- getNonOverlappingIndicesWith (ReassociationIndexRange &rhs) const {
65
- if (rightIdx < rhs.leftIdx ) {
66
- // The intervals do not overlap - concatenate the indices from both.
67
- auto jointFullIndices = getFullIndices ();
68
- jointFullIndices.append (rhs.getFullIndices ());
69
- return jointFullIndices;
70
- }
71
- ReassociationIndices result;
72
- // Handle the chunk left of the overlapping range.
73
- int64_t leftStart = std::min (leftIdx, rhs.leftIdx );
74
- int64_t leftEnd = std::max (leftIdx, rhs.leftIdx );
75
- llvm::append_range (result, llvm::seq (leftStart, leftEnd));
76
- // Handle the chunk right of the overlapping range. Symmetrically, we should
77
- // skip the edge of the overlap AND include the rightmost index.
78
- int64_t rightStart = std::min (rightIdx, rhs.rightIdx ) + 1 ;
79
- int64_t rightEnd = std::max (rightIdx, rhs.rightIdx );
80
- if (rightStart < rightEnd)
81
- llvm::append_range (result, llvm::seq_inclusive (rightStart, rightEnd));
82
- return result;
83
- }
84
-
85
- // / Converts the range into ReassociationIndices.
86
- ReassociationIndices getFullIndices () const {
87
- ReassociationIndices result;
88
- for (int64_t idx = leftIdx; idx <= rightIdx; ++idx) {
89
- result.push_back (idx);
90
- }
91
- return result;
92
- }
93
- };
94
- } // namespace
95
-
96
- // / Starting from `sourceStartIdx`, searches `sourceShape` for the first
97
- // / sequence that can be collapsed into a dynamic dimension (at least one must
98
- // / be present in the source).
99
- // / By default, lazily returns once the first dynamic dimension has been found.
100
- // / Setting `matchGreedily` as `true` will also mark all subsequent
101
- // / source dimensions for collapsing into the target.
102
- static FailureOr<ReassociationIndexRange>
103
- findReassociationRangeForDynamicDim (ArrayRef<int64_t > sourceShape,
104
- int64_t sourceStartIdx,
105
- bool matchGreedily = false ) {
106
- const unsigned numSourceDims = sourceShape.size ();
107
- ReassociationIndexRange sourceShapeAsRange{0 , numSourceDims - 1 };
108
- std::optional<ReassociationIndexRange> resultRange = std::nullopt;
109
-
110
- ReassociationIndexRange iterationRange{sourceStartIdx, sourceStartIdx};
111
- for (; iterationRange.isInRange (sourceShapeAsRange);
112
- iterationRange.rightIdx ++) {
113
- int64_t sourceSize = sourceShape[iterationRange.rightIdx ];
114
- if (sourceSize == ShapedType::kDynamic ) {
115
- resultRange = iterationRange;
116
- break ;
117
- }
118
- }
119
- if (!resultRange)
120
- return failure ();
121
- if (matchGreedily)
122
- resultRange->rightIdx = sourceShapeAsRange.rightIdx ;
123
- return *resultRange;
124
- }
31
+ std::optional<SmallVector<ReassociationIndices>>
32
+ mlir::getReassociationIndicesForCollapse (ArrayRef<int64_t > sourceShape,
33
+ ArrayRef<int64_t > targetShape) {
34
+ if (sourceShape.size () <= targetShape.size ())
35
+ return std::nullopt;
36
+ unsigned sourceDim = 0 ;
37
+ SmallVector<ReassociationIndices> reassociationMap;
38
+ reassociationMap.reserve (targetShape.size ());
125
39
126
- // / Starting from `sourceStartIdx`, searches `sourceShape` for the first
127
- // / sequence of static dimensions such that their product matches `targetSize`.
128
- // / By default, lazily returns once the product matches the target size. Setting
129
- // / `matchGreedily` as `true` will append all neighboring unit dimensions
130
- // / (dimensions of 1) to the match.
131
- static FailureOr<ReassociationIndexRange>
132
- findReassociationRangeForSize (ArrayRef<int64_t > sourceShape,
133
- int64_t sourceStartIdx, int64_t targetSize,
134
- bool matchGreedily = false ) {
135
- const unsigned numSourceDims = sourceShape.size ();
136
- ReassociationIndexRange sourceShapeAsRange{0 , numSourceDims - 1 };
137
- std::optional<ReassociationIndexRange> resultRange = std::nullopt;
138
-
139
- ReassociationIndexRange iterationRange{sourceStartIdx, sourceStartIdx};
40
+ ReassociationIndices currIndices;
140
41
int64_t prodOfCollapsedDims = 1 ;
141
- while (iterationRange.isInRange (sourceShapeAsRange)) {
142
- int64_t sourceSize = sourceShape[iterationRange.rightIdx ];
143
- if (sourceSize == ShapedType::kDynamic ) {
144
- // Reassociation for a static dim cannot include a dynamic dim. Reset
145
- // induction variables to essentially restart the loop from the next
146
- // source dimension.
147
- prodOfCollapsedDims = 1 ;
148
- iterationRange = {iterationRange.rightIdx + 1 ,
149
- iterationRange.rightIdx + 1 };
150
- continue ;
151
- }
152
- prodOfCollapsedDims *= sourceSize;
153
- // If the target size has been exceeded without matching, we need to shift
154
- // the range start right. From the start of the range, roll back the
155
- // multiplication until the target size exceeds the product again.
156
- while (prodOfCollapsedDims > targetSize &&
157
- !iterationRange.containsSingleIndex ()) {
158
- int64_t frontSourceSize = sourceShape[iterationRange.leftIdx ];
159
- prodOfCollapsedDims /= frontSourceSize;
160
- // Shrink the range rightwards
161
- iterationRange.leftIdx ++;
162
- }
163
- // We could've reached the target size with the current dimension,
164
- // also as a result of the above shift to right.
165
- if (prodOfCollapsedDims == targetSize) {
166
- resultRange = iterationRange;
42
+ while (sourceDim < sourceShape.size ()) {
43
+ unsigned targetDim = reassociationMap.size ();
44
+ // If we have mapped all the target dimensions stop and handle the remaining
45
+ // tail of size-1 dimensions explicitly.
46
+ if (targetDim == targetShape.size ())
167
47
break ;
168
- }
169
- // Increment the iteration range
170
- iterationRange.rightIdx ++;
171
- }
172
- if (!resultRange)
173
- return failure ();
174
- if (matchGreedily) {
175
- // We now want to collect all unit dimensions directly after the target
176
- // product match. Advance the iterator to avoid OOB when the product match
177
- // happens at the last element.
178
- iterationRange.rightIdx ++;
179
- while (iterationRange.isInRange (sourceShapeAsRange) &&
180
- sourceShape[iterationRange.rightIdx ] == 1 ) {
181
- resultRange = iterationRange;
182
- iterationRange.rightIdx ++;
183
- }
184
- }
185
- return *resultRange;
186
- }
187
48
188
- // / Attempts to find a valid collapsing reassociation of `sourceShape` into
189
- // / `targetShape` through a simple traversal. If successful, an array of source
190
- // / index ranges is returned, correspondingly to each dimension in the target
191
- // / shape. The resulting indices shall fully cover the `sourceShape` without
192
- // / overlaps.
193
- // /
194
- // / The algorithm is essentially a lazy one, searching for non-greedy matches -
195
- // / it will only yield a greedy match for the last target dimension.
196
- // / FIXME: The algorithm can only backtrack when it needs to append an offset
197
- // / for a static target dimension to the preceding dynamic one (this retains the
198
- // / linear complexity). As feasible, consider adding further backtracking
199
- // / routines to enable more reassociations, e.g.:
200
- // / - ?x2x?x2 into ?x2
201
- static FailureOr<SmallVector<ReassociationIndexRange>>
202
- findReassociationRangesForCollapse (ArrayRef<int64_t > sourceShape,
203
- ArrayRef<int64_t > targetShape) {
204
- unsigned numSourceDims = sourceShape.size (),
205
- numTargetDims = targetShape.size ();
206
- assert (numSourceDims > numTargetDims);
207
- ReassociationIndexRange sourceShapeAsRange{0 , numSourceDims - 1 };
208
-
209
- SmallVector<ReassociationIndexRange> reassocRanges;
210
- reassocRanges.reserve (numTargetDims);
211
- // We'll iterate in strides of 2 to enable pseudo-backtracking for simple
212
- // cases, e.g.:
213
- // - ?x2x3x5 into ?x15
214
- std::optional<int64_t > prevTargetSize = std::nullopt;
215
- for (unsigned targetDimIdx = 0 , sourceDimIdx = 0 ;
216
- targetDimIdx < numTargetDims; ++targetDimIdx) {
217
- int64_t targetSize = targetShape[targetDimIdx];
218
- // Simply check if there are any subsequent target dimensions left - if not,
219
- // the match must be made greedily.
220
- bool shouldMatchGreedily = targetDimIdx == numTargetDims - 1 ;
221
- FailureOr<ReassociationIndexRange> sourceRange;
222
- if (targetSize == ShapedType::kDynamic ) {
223
- sourceRange = findReassociationRangeForDynamicDim (
224
- sourceShape, sourceDimIdx, shouldMatchGreedily);
225
- } else {
226
- sourceRange = findReassociationRangeForSize (
227
- sourceShape, sourceDimIdx, targetSize, shouldMatchGreedily);
49
+ int64_t currTargetShape = targetShape[targetDim];
50
+ while (sourceDim < (sourceShape.size () - 1 ) &&
51
+ sourceShape[sourceDim] != ShapedType::kDynamic &&
52
+ prodOfCollapsedDims * sourceShape[sourceDim] < currTargetShape) {
53
+ prodOfCollapsedDims *= sourceShape[sourceDim];
54
+ currIndices.push_back (sourceDim++);
228
55
}
229
56
230
- // Run sanity checks on the returned index range.
231
- if (failed (sourceRange) || failed (sourceRange->verify ()) ||
232
- !sourceRange->isInRange (sourceShapeAsRange))
233
- return failure ();
234
- if (sourceRange->leftIdx > sourceDimIdx) {
235
- // If some source dimensions had to be skipped in order to find a match,
236
- // they must be collapsed into the directly preceding dynamic dimension.
237
- if (!prevTargetSize || prevTargetSize != ShapedType::kDynamic )
238
- return failure ();
239
- reassocRanges.back ().rightIdx = sourceRange->leftIdx - 1 ;
240
- }
241
-
242
- // Store the gathered information as required for the next iteration.
243
- prevTargetSize = targetSize;
244
- sourceDimIdx = sourceRange->rightIdx + 1 ;
245
- reassocRanges.push_back (*sourceRange);
57
+ // If the current expanded dimension is dynamic, then the collapsed
58
+ // dimensions should also be dynamic and product of all previous unprocessed
59
+ // dimensions of the expanded shape should be 1.
60
+ if (sourceShape[sourceDim] == ShapedType::kDynamic &&
61
+ (currTargetShape != ShapedType::kDynamic || prodOfCollapsedDims != 1 ))
62
+ return std::nullopt;
63
+
64
+ // If the collapsed dim is dynamic, the current expanded dim should also
65
+ // be dynamic.
66
+ if (currTargetShape == ShapedType::kDynamic &&
67
+ sourceShape[sourceDim] != ShapedType::kDynamic )
68
+ return std::nullopt;
69
+
70
+ // For static shapes, if the product of dimensions of the expanded shape
71
+ // should match the collapsed dimension shape.
72
+ if (prodOfCollapsedDims * sourceShape[sourceDim] != currTargetShape)
73
+ return std::nullopt;
74
+
75
+ currIndices.push_back (sourceDim++);
76
+ reassociationMap.emplace_back (ReassociationIndices{});
77
+ std::swap (reassociationMap.back (), currIndices);
78
+ prodOfCollapsedDims = 1 ;
246
79
}
247
- // Fail if the source shape wasn't a full match for the target shape. We only
248
- // need to check the last recorded index - any other gaps should have been
249
- // mended by the main loop.
250
- if (reassocRanges.back ().rightIdx < sourceShapeAsRange.rightIdx )
251
- return failure ();
252
- return reassocRanges;
253
- }
254
-
255
- // / A variant of `findReassociationRangesForCollapse(...)` that can also scan
256
- // / the shapes right-to-left.
257
- static FailureOr<SmallVector<ReassociationIndexRange>>
258
- findReassociationRangesForCollapse (ArrayRef<int64_t > sourceShape,
259
- ArrayRef<int64_t > targetShape,
260
- bool iterateRightToLeft) {
261
- if (!iterateRightToLeft)
262
- return findReassociationRangesForCollapse (sourceShape, targetShape);
263
- // NB: To iterate right-to-left, we currently reverse the shapes and then
264
- // reverse the result back. The reversed shapes must not be temporary, as
265
- // we're passing through an ArrayRef.
266
- // FIXME: It would be preferable to avoid the expensive copies. At the moment,
267
- // this approach is chosen for readability of the main implementation.
268
- std::vector<int64_t > sourceToReverse = sourceShape.vec (),
269
- targetToReverse = targetShape.vec ();
270
- std::reverse (sourceToReverse.begin (), sourceToReverse.end ());
271
- std::reverse (targetToReverse.begin (), targetToReverse.end ());
272
- auto invertedRanges =
273
- findReassociationRangesForCollapse (sourceToReverse, targetToReverse);
274
- if (failed (invertedRanges))
275
- return failure ();
276
- SmallVector<ReassociationIndexRange> &rangesToInvert = *invertedRanges;
277
- unsigned numSourceDims = sourceShape.size ();
278
- // We have received the ranges for inverted shapes. Now we have to invert
279
- // the ranges back to correspond with the original source shape.
280
- for (auto &range : rangesToInvert) {
281
- int64_t invLeftIdx = range.leftIdx , invRightIdx = range.rightIdx ;
282
- range.leftIdx = numSourceDims - 1 - invRightIdx;
283
- range.rightIdx = numSourceDims - 1 - invLeftIdx;
284
- }
285
- // Also invert the ordering of the ranges to correspond with the original
286
- // target shape.
287
- std::reverse (rangesToInvert.begin (), rangesToInvert.end ());
288
- return rangesToInvert;
289
- }
290
-
291
- std::optional<SmallVector<ReassociationIndices>>
292
- mlir::getReassociationIndicesForCollapse (ArrayRef<int64_t > sourceShape,
293
- ArrayRef<int64_t > targetShape) {
294
- unsigned numSourceDims = sourceShape.size (),
295
- numTargetDims = targetShape.size ();
296
- // We're supposed to search for a collapsing reassociation. If the sizes
297
- // match, there's no actual collapsing taking place - it's either a no-op or a
298
- // `tensor.reshape`-style reassociation (that would be beyond the scope of
299
- // this utility).
300
- if (numSourceDims <= numTargetDims)
301
- return std::nullopt;
302
- // Early handling for scalar target types.
303
- if (numTargetDims == 0 ) {
304
- ReassociationIndices allSourceIndices;
305
- allSourceIndices.reserve (numSourceDims);
306
- for (unsigned sourceDimIdx = 0 ; sourceDimIdx < numSourceDims;
307
- ++sourceDimIdx) {
308
- int64_t sourceSize = sourceShape[sourceDimIdx];
309
- // All source dimensions must be unit or dynamic.
310
- if (sourceSize != 1 && sourceSize != ShapedType::kDynamic )
311
- return std::nullopt;
312
- allSourceIndices.push_back (sourceDimIdx);
313
- }
314
- return SmallVector<ReassociationIndices>{allSourceIndices};
315
- }
316
-
317
- // Collect source ranges by iterating over the target shape left-to-right.
318
- FailureOr<SmallVector<ReassociationIndexRange>> maybeForwardRanges =
319
- findReassociationRangesForCollapse (sourceShape, targetShape);
320
- if (failed (maybeForwardRanges))
321
- return std::nullopt;
322
- auto &ranges = *maybeForwardRanges;
323
- // Now do the same in reverse. We need to get another valid reassociation
324
- // through some other strategy, and then compare the results in order to
325
- // disambiguate mixed subshapes, such as:
326
- // ?x?x? into ?x?, ?x2x? into ?x?, ?x2x3x6x? into ?x6x?
327
- // This leads us to lose some of the reassociation opportunities that can only
328
- // be found by iterating in a certain direction, e.g. 2x2x? into 2x? - without
329
- // backtracking, the algorithm will fail right-to-left. However, this is the
330
- // best way to preserve correctness.
331
- FailureOr<SmallVector<ReassociationIndexRange>> maybeReverseRanges =
332
- findReassociationRangesForCollapse (sourceShape, targetShape,
333
- /* iterateRightToLeft=*/ true );
334
- if (failed (maybeReverseRanges))
335
- return std::nullopt;
336
- auto &reverseRanges = *maybeReverseRanges;
337
-
338
- if (ranges.size () != numTargetDims || reverseRanges.size () != numTargetDims)
80
+ // All the dimensions in the target must have been processed.
81
+ if (reassociationMap.size () != targetShape.size ())
339
82
return std::nullopt;
340
- // Now we can check for ambiguity of each target dimension's reassociation. If
341
- // successful, we put the full indices into our result map for the target
342
- // shape.
343
- SmallVector<ReassociationIndices> reassociationMap (numTargetDims);
344
- for (unsigned targetDimIdx = 0 ; targetDimIdx < numTargetDims;
345
- ++targetDimIdx) {
346
- ReassociationIndexRange &range = ranges[targetDimIdx];
347
- ReassociationIndexRange &reverseRange = reverseRanges[targetDimIdx];
348
- // Get non-overlapping indices between the ranges
349
- ReassociationIndices nonMatchingIndices =
350
- range.getNonOverlappingIndicesWith (reverseRange);
351
- // Unit dimensions can be collapsed wherever - this is the only ambiguity
352
- // that we allow.
353
- for (int64_t sourceDimIdx : nonMatchingIndices) {
354
- if (sourceShape[sourceDimIdx] != 1 )
355
- return std::nullopt;
356
- }
357
- reassociationMap[targetDimIdx] = range.getFullIndices ();
83
+ // Process any remaining entries in the source shape. They all need to be
84
+ // 1 or dynamic.
85
+ for (; sourceDim < sourceShape.size (); sourceDim++) {
86
+ if (sourceShape[sourceDim] != ShapedType::kDynamic &&
87
+ sourceShape[sourceDim] != 1 )
88
+ return std::nullopt;
89
+ // The map is empty when the target type is a scalar.
90
+ if (!reassociationMap.empty ())
91
+ reassociationMap.back ().push_back (sourceDim);
358
92
}
359
93
return reassociationMap;
360
94
}
0 commit comments