Skip to content

Commit 6e5a142

Browse files
IanWood1 and AGindinson authored
[mlir] Reapply "Loosen restrictions on folding dynamic reshapes" (#142827)
The original PR #137963 had a nvidia bot failure. This appears to be a flaky test because rerunning the build was successful. This change needs commit 6f2ba47 to fix incorrect usage of `getReassociationIndicesForCollapse`. Reverts #142639 Co-authored-by: Artem Gindinson <[email protected]>
1 parent 2d35b56 commit 6e5a142

File tree

5 files changed

+560
-59
lines changed

5 files changed

+560
-59
lines changed

mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp

Lines changed: 319 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,10 @@
1010

1111
#include "mlir/IR/AffineMap.h"
1212
#include "mlir/IR/Builders.h"
13+
#include "mlir/IR/BuiltinTypeInterfaces.h"
14+
#include "llvm/ADT/ArrayRef.h"
15+
#include "llvm/ADT/SmallVector.h"
16+
#include "llvm/Support/LogicalResult.h"
1317

1418
#include <numeric>
1519
#include <optional>
@@ -28,67 +32,329 @@ mlir::getReassociationIndicesForReshape(ShapedType sourceType,
2832
return std::nullopt;
2933
}
3034

31-
std::optional<SmallVector<ReassociationIndices>>
32-
mlir::getReassociationIndicesForCollapse(ArrayRef<int64_t> sourceShape,
33-
ArrayRef<int64_t> targetShape) {
34-
if (sourceShape.size() <= targetShape.size())
35-
return std::nullopt;
36-
unsigned sourceDim = 0;
37-
SmallVector<ReassociationIndices> reassociationMap;
38-
reassociationMap.reserve(targetShape.size());
35+
namespace {
36+
/// A simple struct to represent ReassociationIndices as an inclusive interval.
37+
/// It's designed to be feasibly minimal, so the call sites should manage the
38+
/// validity of the range manually.
39+
struct ReassociationIndexRange {
40+
/// FIXME: Signed type is used for consistency with ReassociationIndices.
41+
/// We should consider refactoring all reassociation utilities to use unsigned
42+
/// types.
43+
int64_t leftIdx = 0, rightIdx = 0;
44+
45+
/// Util for manual checks of the range's validity
46+
LogicalResult verify() const {
47+
return leftIdx >= 0 && (leftIdx <= rightIdx) ? success() : failure();
48+
}
49+
50+
/// Checks range's containment within another range. Treats the edges
51+
/// non-exclusively.
52+
bool isInRange(const ReassociationIndexRange &outerRange) const {
53+
return leftIdx >= outerRange.leftIdx && rightIdx <= outerRange.rightIdx;
54+
}
55+
56+
unsigned size() const {
57+
assert(succeeded(verify()));
58+
return rightIdx - leftIdx + 1;
59+
}
60+
bool containsSingleIndex() const { return size() == 1; }
61+
62+
/// Collects indices that do not overlap between this and another range.
63+
ReassociationIndices
64+
getNonOverlappingIndicesWith(ReassociationIndexRange &rhs) const {
65+
if (rightIdx < rhs.leftIdx) {
66+
// The intervals do not overlap - concatenate the indices from both.
67+
auto jointFullIndices = getFullIndices();
68+
jointFullIndices.append(rhs.getFullIndices());
69+
return jointFullIndices;
70+
}
71+
ReassociationIndices result;
72+
// Handle the chunk left of the overlapping range.
73+
int64_t leftStart = std::min(leftIdx, rhs.leftIdx);
74+
int64_t leftEnd = std::max(leftIdx, rhs.leftIdx);
75+
llvm::append_range(result, llvm::seq(leftStart, leftEnd));
76+
// Handle the chunk right of the overlapping range. Symmetrically, we should
77+
// skip the edge of the overlap AND include the rightmost index.
78+
int64_t rightStart = std::min(rightIdx, rhs.rightIdx) + 1;
79+
int64_t rightEnd = std::max(rightIdx, rhs.rightIdx);
80+
if (rightStart < rightEnd)
81+
llvm::append_range(result, llvm::seq_inclusive(rightStart, rightEnd));
82+
return result;
83+
}
84+
85+
/// Converts the range into ReassociationIndices.
86+
ReassociationIndices getFullIndices() const {
87+
ReassociationIndices result;
88+
for (int64_t idx = leftIdx; idx <= rightIdx; ++idx) {
89+
result.push_back(idx);
90+
}
91+
return result;
92+
}
93+
};
94+
} // namespace
95+
96+
/// Starting from `sourceStartIdx`, searches `sourceShape` for the first
97+
/// sequence that can be collapsed into a dynamic dimension (at least one must
98+
/// be present in the source).
99+
/// By default, lazily returns once the first dynamic dimension has been found.
100+
/// Setting `matchGreedily` as `true` will also mark all subsequent
101+
/// source dimensions for collapsing into the target.
102+
static FailureOr<ReassociationIndexRange>
103+
findReassociationRangeForDynamicDim(ArrayRef<int64_t> sourceShape,
104+
int64_t sourceStartIdx,
105+
bool matchGreedily = false) {
106+
const unsigned numSourceDims = sourceShape.size();
107+
ReassociationIndexRange sourceShapeAsRange{0, numSourceDims - 1};
108+
std::optional<ReassociationIndexRange> resultRange = std::nullopt;
109+
110+
ReassociationIndexRange iterationRange{sourceStartIdx, sourceStartIdx};
111+
for (; iterationRange.isInRange(sourceShapeAsRange);
112+
iterationRange.rightIdx++) {
113+
int64_t sourceSize = sourceShape[iterationRange.rightIdx];
114+
if (sourceSize == ShapedType::kDynamic) {
115+
resultRange = iterationRange;
116+
break;
117+
}
118+
}
119+
if (!resultRange)
120+
return failure();
121+
if (matchGreedily)
122+
resultRange->rightIdx = sourceShapeAsRange.rightIdx;
123+
return *resultRange;
124+
}
39125

40-
ReassociationIndices currIndices;
126+
/// Starting from `sourceStartIdx`, searches `sourceShape` for the first
127+
/// sequence of static dimensions such that their product matches `targetSize`.
128+
/// By default, lazily returns once the product matches the target size. Setting
129+
/// `matchGreedily` as `true` will append all neighboring unit dimensions
130+
/// (dimensions of 1) to the match.
131+
static FailureOr<ReassociationIndexRange>
132+
findReassociationRangeForSize(ArrayRef<int64_t> sourceShape,
133+
int64_t sourceStartIdx, int64_t targetSize,
134+
bool matchGreedily = false) {
135+
const unsigned numSourceDims = sourceShape.size();
136+
ReassociationIndexRange sourceShapeAsRange{0, numSourceDims - 1};
137+
std::optional<ReassociationIndexRange> resultRange = std::nullopt;
138+
139+
ReassociationIndexRange iterationRange{sourceStartIdx, sourceStartIdx};
41140
int64_t prodOfCollapsedDims = 1;
42-
while (sourceDim < sourceShape.size()) {
43-
unsigned targetDim = reassociationMap.size();
44-
// If we have mapped all the target dimensions stop and handle the remaining
45-
// tail of size-1 dimensions explicitly.
46-
if (targetDim == targetShape.size())
141+
while (iterationRange.isInRange(sourceShapeAsRange)) {
142+
int64_t sourceSize = sourceShape[iterationRange.rightIdx];
143+
if (sourceSize == ShapedType::kDynamic) {
144+
// Reassociation for a static dim cannot include a dynamic dim. Reset
145+
// induction variables to essentially restart the loop from the next
146+
// source dimension.
147+
prodOfCollapsedDims = 1;
148+
iterationRange = {iterationRange.rightIdx + 1,
149+
iterationRange.rightIdx + 1};
150+
continue;
151+
}
152+
prodOfCollapsedDims *= sourceSize;
153+
// If the target size has been exceeded without matching, we need to shift
154+
// the range start right. From the start of the range, roll back the
155+
// multiplication until the target size exceeds the product again.
156+
while (prodOfCollapsedDims > targetSize &&
157+
!iterationRange.containsSingleIndex()) {
158+
int64_t frontSourceSize = sourceShape[iterationRange.leftIdx];
159+
prodOfCollapsedDims /= frontSourceSize;
160+
// Shrink the range rightwards
161+
iterationRange.leftIdx++;
162+
}
163+
// We could've reached the target size with the current dimension,
164+
// also as a result of the above shift to right.
165+
if (prodOfCollapsedDims == targetSize) {
166+
resultRange = iterationRange;
47167
break;
168+
}
169+
// Increment the iteration range
170+
iterationRange.rightIdx++;
171+
}
172+
if (!resultRange)
173+
return failure();
174+
if (matchGreedily) {
175+
// We now want to collect all unit dimensions directly after the target
176+
// product match. Advance the iterator to avoid OOB when the product match
177+
// happens at the last element.
178+
iterationRange.rightIdx++;
179+
while (iterationRange.isInRange(sourceShapeAsRange) &&
180+
sourceShape[iterationRange.rightIdx] == 1) {
181+
resultRange = iterationRange;
182+
iterationRange.rightIdx++;
183+
}
184+
}
185+
return *resultRange;
186+
}
48187

49-
int64_t currTargetShape = targetShape[targetDim];
50-
while (sourceDim < (sourceShape.size() - 1) &&
51-
sourceShape[sourceDim] != ShapedType::kDynamic &&
52-
prodOfCollapsedDims * sourceShape[sourceDim] < currTargetShape) {
53-
prodOfCollapsedDims *= sourceShape[sourceDim];
54-
currIndices.push_back(sourceDim++);
188+
/// Attempts to find a valid collapsing reassociation of `sourceShape` into
189+
/// `targetShape` through a simple traversal. If successful, an array of source
190+
/// index ranges is returned, correspondingly to each dimension in the target
191+
/// shape. The resulting indices shall fully cover the `sourceShape` without
192+
/// overlaps.
193+
///
194+
/// The algorithm is essentially a lazy one, searching for non-greedy matches -
195+
/// it will only yield a greedy match for the last target dimension.
196+
/// FIXME: The algorithm can only backtrack when it needs to append an offset
197+
/// for a static target dimension to the preceding dynamic one (this retains the
198+
/// linear complexity). As feasible, consider adding further backtracking
199+
/// routines to enable more reassociations, e.g.:
200+
/// - ?x2x?x2 into ?x2
201+
static FailureOr<SmallVector<ReassociationIndexRange>>
202+
findReassociationRangesForCollapse(ArrayRef<int64_t> sourceShape,
203+
ArrayRef<int64_t> targetShape) {
204+
unsigned numSourceDims = sourceShape.size(),
205+
numTargetDims = targetShape.size();
206+
assert(numSourceDims > numTargetDims);
207+
ReassociationIndexRange sourceShapeAsRange{0, numSourceDims - 1};
208+
209+
SmallVector<ReassociationIndexRange> reassocRanges;
210+
reassocRanges.reserve(numTargetDims);
211+
// We'll iterate in strides of 2 to enable pseudo-backtracking for simple
212+
// cases, e.g.:
213+
// - ?x2x3x5 into ?x15
214+
std::optional<int64_t> prevTargetSize = std::nullopt;
215+
for (unsigned targetDimIdx = 0, sourceDimIdx = 0;
216+
targetDimIdx < numTargetDims; ++targetDimIdx) {
217+
int64_t targetSize = targetShape[targetDimIdx];
218+
// Simply check if there are any subsequent target dimensions left - if not,
219+
// the match must be made greedily.
220+
bool shouldMatchGreedily = targetDimIdx == numTargetDims - 1;
221+
FailureOr<ReassociationIndexRange> sourceRange;
222+
if (targetSize == ShapedType::kDynamic) {
223+
sourceRange = findReassociationRangeForDynamicDim(
224+
sourceShape, sourceDimIdx, shouldMatchGreedily);
225+
} else {
226+
sourceRange = findReassociationRangeForSize(
227+
sourceShape, sourceDimIdx, targetSize, shouldMatchGreedily);
55228
}
56229

57-
// If the current expanded dimension is dynamic, then the collapsed
58-
// dimensions should also be dynamic and product of all previous unprocessed
59-
// dimensions of the expanded shape should be 1.
60-
if (sourceShape[sourceDim] == ShapedType::kDynamic &&
61-
(currTargetShape != ShapedType::kDynamic || prodOfCollapsedDims != 1))
62-
return std::nullopt;
63-
64-
// If the collapsed dim is dynamic, the current expanded dim should also
65-
// be dynamic.
66-
if (currTargetShape == ShapedType::kDynamic &&
67-
sourceShape[sourceDim] != ShapedType::kDynamic)
68-
return std::nullopt;
69-
70-
// For static shapes, if the product of dimensions of the expanded shape
71-
// should match the collapsed dimension shape.
72-
if (prodOfCollapsedDims * sourceShape[sourceDim] != currTargetShape)
73-
return std::nullopt;
74-
75-
currIndices.push_back(sourceDim++);
76-
reassociationMap.emplace_back(ReassociationIndices{});
77-
std::swap(reassociationMap.back(), currIndices);
78-
prodOfCollapsedDims = 1;
230+
// Run sanity checks on the returned index range.
231+
if (failed(sourceRange) || failed(sourceRange->verify()) ||
232+
!sourceRange->isInRange(sourceShapeAsRange))
233+
return failure();
234+
if (sourceRange->leftIdx > sourceDimIdx) {
235+
// If some source dimensions had to be skipped in order to find a match,
236+
// they must be collapsed into the directly preceding dynamic dimension.
237+
if (!prevTargetSize || prevTargetSize != ShapedType::kDynamic)
238+
return failure();
239+
reassocRanges.back().rightIdx = sourceRange->leftIdx - 1;
240+
}
241+
242+
// Store the gathered information as required for the next iteration.
243+
prevTargetSize = targetSize;
244+
sourceDimIdx = sourceRange->rightIdx + 1;
245+
reassocRanges.push_back(*sourceRange);
79246
}
80-
// All the dimensions in the target must have been processed.
81-
if (reassociationMap.size() != targetShape.size())
247+
// Fail if the source shape wasn't a full match for the target shape. We only
248+
// need to check the last recorded index - any other gaps should have been
249+
// mended by the main loop.
250+
if (reassocRanges.back().rightIdx < sourceShapeAsRange.rightIdx)
251+
return failure();
252+
return reassocRanges;
253+
}
254+
255+
/// A variant of `findReassociationRangesForCollapse(...)` that can also scan
256+
/// the shapes right-to-left.
257+
static FailureOr<SmallVector<ReassociationIndexRange>>
258+
findReassociationRangesForCollapse(ArrayRef<int64_t> sourceShape,
259+
ArrayRef<int64_t> targetShape,
260+
bool iterateRightToLeft) {
261+
if (!iterateRightToLeft)
262+
return findReassociationRangesForCollapse(sourceShape, targetShape);
263+
// NB: To iterate right-to-left, we currently reverse the shapes and then
264+
// reverse the result back. The reversed shapes must not be temporary, as
265+
// we're passing through an ArrayRef.
266+
// FIXME: It would be preferable to avoid the expensive copies. At the moment,
267+
// this approach is chosen for readability of the main implementation.
268+
std::vector<int64_t> sourceToReverse = sourceShape.vec(),
269+
targetToReverse = targetShape.vec();
270+
std::reverse(sourceToReverse.begin(), sourceToReverse.end());
271+
std::reverse(targetToReverse.begin(), targetToReverse.end());
272+
auto invertedRanges =
273+
findReassociationRangesForCollapse(sourceToReverse, targetToReverse);
274+
if (failed(invertedRanges))
275+
return failure();
276+
SmallVector<ReassociationIndexRange> &rangesToInvert = *invertedRanges;
277+
unsigned numSourceDims = sourceShape.size();
278+
// We have received the ranges for inverted shapes. Now we have to invert
279+
// the ranges back to correspond with the original source shape.
280+
for (auto &range : rangesToInvert) {
281+
int64_t invLeftIdx = range.leftIdx, invRightIdx = range.rightIdx;
282+
range.leftIdx = numSourceDims - 1 - invRightIdx;
283+
range.rightIdx = numSourceDims - 1 - invLeftIdx;
284+
}
285+
// Also invert the ordering of the ranges to correspond with the original
286+
// target shape.
287+
std::reverse(rangesToInvert.begin(), rangesToInvert.end());
288+
return rangesToInvert;
289+
}
290+
291+
std::optional<SmallVector<ReassociationIndices>>
292+
mlir::getReassociationIndicesForCollapse(ArrayRef<int64_t> sourceShape,
293+
ArrayRef<int64_t> targetShape) {
294+
unsigned numSourceDims = sourceShape.size(),
295+
numTargetDims = targetShape.size();
296+
// We're supposed to search for a collapsing reassociation. If the sizes
297+
// match, there's no actual collapsing taking place - it's either a no-op or a
298+
// `tensor.reshape`-style reassociation (that would be beyond the scope of
299+
// this utility).
300+
if (numSourceDims <= numTargetDims)
301+
return std::nullopt;
302+
// Early handling for scalar target types.
303+
if (numTargetDims == 0) {
304+
ReassociationIndices allSourceIndices;
305+
allSourceIndices.reserve(numSourceDims);
306+
for (unsigned sourceDimIdx = 0; sourceDimIdx < numSourceDims;
307+
++sourceDimIdx) {
308+
int64_t sourceSize = sourceShape[sourceDimIdx];
309+
// All source dimensions must be unit or dynamic.
310+
if (sourceSize != 1 && sourceSize != ShapedType::kDynamic)
311+
return std::nullopt;
312+
allSourceIndices.push_back(sourceDimIdx);
313+
}
314+
return SmallVector<ReassociationIndices>{allSourceIndices};
315+
}
316+
317+
// Collect source ranges by iterating over the target shape left-to-right.
318+
FailureOr<SmallVector<ReassociationIndexRange>> maybeForwardRanges =
319+
findReassociationRangesForCollapse(sourceShape, targetShape);
320+
if (failed(maybeForwardRanges))
321+
return std::nullopt;
322+
auto &ranges = *maybeForwardRanges;
323+
// Now do the same in reverse. We need to get another valid reassociation
324+
// through some other strategy, and then compare the results in order to
325+
// disambiguate mixed subshapes, such as:
326+
// ?x?x? into ?x?, ?x2x? into ?x?, ?x2x3x6x? into ?x6x?
327+
// This leads us to lose some of the reassociation opportunities that can only
328+
// be found by iterating in a certain direction, e.g. 2x2x? into 2x? - without
329+
// backtracking, the algorithm will fail right-to-left. However, this is the
330+
// best way to preserve correctness.
331+
FailureOr<SmallVector<ReassociationIndexRange>> maybeReverseRanges =
332+
findReassociationRangesForCollapse(sourceShape, targetShape,
333+
/*iterateRightToLeft=*/true);
334+
if (failed(maybeReverseRanges))
335+
return std::nullopt;
336+
auto &reverseRanges = *maybeReverseRanges;
337+
338+
if (ranges.size() != numTargetDims || reverseRanges.size() != numTargetDims)
82339
return std::nullopt;
83-
// Process any remaining entries in the source shape. They all need to be
84-
// 1 or dynamic.
85-
for (; sourceDim < sourceShape.size(); sourceDim++) {
86-
if (sourceShape[sourceDim] != ShapedType::kDynamic &&
87-
sourceShape[sourceDim] != 1)
88-
return std::nullopt;
89-
// The map is empty when the target type is a scalar.
90-
if (!reassociationMap.empty())
91-
reassociationMap.back().push_back(sourceDim);
340+
// Now we can check for ambiguity of each target dimension's reassociation. If
341+
// successful, we put the full indices into our result map for the target
342+
// shape.
343+
SmallVector<ReassociationIndices> reassociationMap(numTargetDims);
344+
for (unsigned targetDimIdx = 0; targetDimIdx < numTargetDims;
345+
++targetDimIdx) {
346+
ReassociationIndexRange &range = ranges[targetDimIdx];
347+
ReassociationIndexRange &reverseRange = reverseRanges[targetDimIdx];
348+
// Get non-overlapping indices between the ranges
349+
ReassociationIndices nonMatchingIndices =
350+
range.getNonOverlappingIndicesWith(reverseRange);
351+
// Unit dimensions can be collapsed wherever - this is the only ambiguity
352+
// that we allow.
353+
for (int64_t sourceDimIdx : nonMatchingIndices) {
354+
if (sourceShape[sourceDimIdx] != 1)
355+
return std::nullopt;
356+
}
357+
reassociationMap[targetDimIdx] = range.getFullIndices();
92358
}
93359
return reassociationMap;
94360
}

0 commit comments

Comments
 (0)