Skip to content

Commit f5a2f00

Browse files
authored
Revert "[mlir][tensor] Loosen restrictions on folding dynamic reshapes" (#142639)
Reverts #137963 --------- Signed-off-by: Ian Wood <[email protected]>
1 parent 1340ecf commit f5a2f00

File tree

5 files changed

+59
-560
lines changed

5 files changed

+59
-560
lines changed

mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp

Lines changed: 53 additions & 319 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,6 @@
1010

1111
#include "mlir/IR/AffineMap.h"
1212
#include "mlir/IR/Builders.h"
13-
#include "mlir/IR/BuiltinTypeInterfaces.h"
14-
#include "llvm/ADT/ArrayRef.h"
15-
#include "llvm/ADT/SmallVector.h"
16-
#include "llvm/Support/LogicalResult.h"
1713

1814
#include <numeric>
1915
#include <optional>
@@ -32,329 +28,67 @@ mlir::getReassociationIndicesForReshape(ShapedType sourceType,
3228
return std::nullopt;
3329
}
3430

35-
namespace {
36-
/// A simple struct to represent ReassociationIndices as an inclusive interval.
37-
/// It's designed to be feasibly minimal, so the call sites should manage the
38-
/// validity of the range manually.
39-
struct ReassociationIndexRange {
40-
/// FIXME: Signed type is used for consistency with ReassociationIndices.
41-
/// We should consider refactoring all reassociation utilities to use unsigned
42-
/// types.
43-
int64_t leftIdx = 0, rightIdx = 0;
44-
45-
/// Util for manual checks of the range's validity
46-
LogicalResult verify() const {
47-
return leftIdx >= 0 && (leftIdx <= rightIdx) ? success() : failure();
48-
}
49-
50-
/// Checks range's containment within another range. Treats the edges
51-
/// non-exclusively.
52-
bool isInRange(const ReassociationIndexRange &outerRange) const {
53-
return leftIdx >= outerRange.leftIdx && rightIdx <= outerRange.rightIdx;
54-
}
55-
56-
unsigned size() const {
57-
assert(succeeded(verify()));
58-
return rightIdx - leftIdx + 1;
59-
}
60-
bool containsSingleIndex() const { return size() == 1; }
61-
62-
/// Collects indices that do not overlap between this and another range.
63-
ReassociationIndices
64-
getNonOverlappingIndicesWith(ReassociationIndexRange &rhs) const {
65-
if (rightIdx < rhs.leftIdx) {
66-
// The intervals do not overlap - concatenate the indices from both.
67-
auto jointFullIndices = getFullIndices();
68-
jointFullIndices.append(rhs.getFullIndices());
69-
return jointFullIndices;
70-
}
71-
ReassociationIndices result;
72-
// Handle the chunk left of the overlapping range.
73-
int64_t leftStart = std::min(leftIdx, rhs.leftIdx);
74-
int64_t leftEnd = std::max(leftIdx, rhs.leftIdx);
75-
llvm::append_range(result, llvm::seq(leftStart, leftEnd));
76-
// Handle the chunk right of the overlapping range. Symmetrically, we should
77-
// skip the edge of the overlap AND include the rightmost index.
78-
int64_t rightStart = std::min(rightIdx, rhs.rightIdx) + 1;
79-
int64_t rightEnd = std::max(rightIdx, rhs.rightIdx);
80-
if (rightStart < rightEnd)
81-
llvm::append_range(result, llvm::seq_inclusive(rightStart, rightEnd));
82-
return result;
83-
}
84-
85-
/// Converts the range into ReassociationIndices.
86-
ReassociationIndices getFullIndices() const {
87-
ReassociationIndices result;
88-
for (int64_t idx = leftIdx; idx <= rightIdx; ++idx) {
89-
result.push_back(idx);
90-
}
91-
return result;
92-
}
93-
};
94-
} // namespace
95-
96-
/// Starting from `sourceStartIdx`, searches `sourceShape` for the first
97-
/// sequence that can be collapsed into a dynamic dimension (at least one must
98-
/// be present in the source).
99-
/// By default, lazily returns once the first dynamic dimension has been found.
100-
/// Setting `matchGreedily` as `true` will also mark all subsequent
101-
/// source dimensions for collapsing into the target.
102-
static FailureOr<ReassociationIndexRange>
103-
findReassociationRangeForDynamicDim(ArrayRef<int64_t> sourceShape,
104-
int64_t sourceStartIdx,
105-
bool matchGreedily = false) {
106-
const unsigned numSourceDims = sourceShape.size();
107-
ReassociationIndexRange sourceShapeAsRange{0, numSourceDims - 1};
108-
std::optional<ReassociationIndexRange> resultRange = std::nullopt;
109-
110-
ReassociationIndexRange iterationRange{sourceStartIdx, sourceStartIdx};
111-
for (; iterationRange.isInRange(sourceShapeAsRange);
112-
iterationRange.rightIdx++) {
113-
int64_t sourceSize = sourceShape[iterationRange.rightIdx];
114-
if (sourceSize == ShapedType::kDynamic) {
115-
resultRange = iterationRange;
116-
break;
117-
}
118-
}
119-
if (!resultRange)
120-
return failure();
121-
if (matchGreedily)
122-
resultRange->rightIdx = sourceShapeAsRange.rightIdx;
123-
return *resultRange;
124-
}
31+
std::optional<SmallVector<ReassociationIndices>>
32+
mlir::getReassociationIndicesForCollapse(ArrayRef<int64_t> sourceShape,
33+
ArrayRef<int64_t> targetShape) {
34+
if (sourceShape.size() <= targetShape.size())
35+
return std::nullopt;
36+
unsigned sourceDim = 0;
37+
SmallVector<ReassociationIndices> reassociationMap;
38+
reassociationMap.reserve(targetShape.size());
12539

126-
/// Starting from `sourceStartIdx`, searches `sourceShape` for the first
127-
/// sequence of static dimensions such that their product matches `targetSize`.
128-
/// By default, lazily returns once the product matches the target size. Setting
129-
/// `matchGreedily` as `true` will append all neighboring unit dimensions
130-
/// (dimensions of 1) to the match.
131-
static FailureOr<ReassociationIndexRange>
132-
findReassociationRangeForSize(ArrayRef<int64_t> sourceShape,
133-
int64_t sourceStartIdx, int64_t targetSize,
134-
bool matchGreedily = false) {
135-
const unsigned numSourceDims = sourceShape.size();
136-
ReassociationIndexRange sourceShapeAsRange{0, numSourceDims - 1};
137-
std::optional<ReassociationIndexRange> resultRange = std::nullopt;
138-
139-
ReassociationIndexRange iterationRange{sourceStartIdx, sourceStartIdx};
40+
ReassociationIndices currIndices;
14041
int64_t prodOfCollapsedDims = 1;
141-
while (iterationRange.isInRange(sourceShapeAsRange)) {
142-
int64_t sourceSize = sourceShape[iterationRange.rightIdx];
143-
if (sourceSize == ShapedType::kDynamic) {
144-
// Reassociation for a static dim cannot include a dynamic dim. Reset
145-
// induction variables to essentially restart the loop from the next
146-
// source dimension.
147-
prodOfCollapsedDims = 1;
148-
iterationRange = {iterationRange.rightIdx + 1,
149-
iterationRange.rightIdx + 1};
150-
continue;
151-
}
152-
prodOfCollapsedDims *= sourceSize;
153-
// If the target size has been exceeded without matching, we need to shift
154-
// the range start right. From the start of the range, roll back the
155-
// multiplication until the target size exceeds the product again.
156-
while (prodOfCollapsedDims > targetSize &&
157-
!iterationRange.containsSingleIndex()) {
158-
int64_t frontSourceSize = sourceShape[iterationRange.leftIdx];
159-
prodOfCollapsedDims /= frontSourceSize;
160-
// Shrink the range rightwards
161-
iterationRange.leftIdx++;
162-
}
163-
// We could've reached the target size with the current dimension,
164-
// also as a result of the above shift to right.
165-
if (prodOfCollapsedDims == targetSize) {
166-
resultRange = iterationRange;
42+
while (sourceDim < sourceShape.size()) {
43+
unsigned targetDim = reassociationMap.size();
44+
// If we have mapped all the target dimensions, stop and handle the remaining
45+
// tail of size-1 dimensions explicitly.
46+
if (targetDim == targetShape.size())
16747
break;
168-
}
169-
// Increment the iteration range
170-
iterationRange.rightIdx++;
171-
}
172-
if (!resultRange)
173-
return failure();
174-
if (matchGreedily) {
175-
// We now want to collect all unit dimensions directly after the target
176-
// product match. Advance the iterator to avoid OOB when the product match
177-
// happens at the last element.
178-
iterationRange.rightIdx++;
179-
while (iterationRange.isInRange(sourceShapeAsRange) &&
180-
sourceShape[iterationRange.rightIdx] == 1) {
181-
resultRange = iterationRange;
182-
iterationRange.rightIdx++;
183-
}
184-
}
185-
return *resultRange;
186-
}
18748

188-
/// Attempts to find a valid collapsing reassociation of `sourceShape` into
189-
/// `targetShape` through a simple traversal. If successful, an array of source
190-
/// index ranges is returned, corresponding to each dimension in the target
191-
/// shape. The resulting indices shall fully cover the `sourceShape` without
192-
/// overlaps.
193-
///
194-
/// The algorithm is essentially a lazy one, searching for non-greedy matches -
195-
/// it will only yield a greedy match for the last target dimension.
196-
/// FIXME: The algorithm can only backtrack when it needs to append an offset
197-
/// for a static target dimension to the preceding dynamic one (this retains the
198-
/// linear complexity). As feasible, consider adding further backtracking
199-
/// routines to enable more reassociations, e.g.:
200-
/// - ?x2x?x2 into ?x2
201-
static FailureOr<SmallVector<ReassociationIndexRange>>
202-
findReassociationRangesForCollapse(ArrayRef<int64_t> sourceShape,
203-
ArrayRef<int64_t> targetShape) {
204-
unsigned numSourceDims = sourceShape.size(),
205-
numTargetDims = targetShape.size();
206-
assert(numSourceDims > numTargetDims);
207-
ReassociationIndexRange sourceShapeAsRange{0, numSourceDims - 1};
208-
209-
SmallVector<ReassociationIndexRange> reassocRanges;
210-
reassocRanges.reserve(numTargetDims);
211-
// We'll iterate in strides of 2 to enable pseudo-backtracking for simple
212-
// cases, e.g.:
213-
// - ?x2x3x5 into ?x15
214-
std::optional<int64_t> prevTargetSize = std::nullopt;
215-
for (unsigned targetDimIdx = 0, sourceDimIdx = 0;
216-
targetDimIdx < numTargetDims; ++targetDimIdx) {
217-
int64_t targetSize = targetShape[targetDimIdx];
218-
// Simply check if there are any subsequent target dimensions left - if not,
219-
// the match must be made greedily.
220-
bool shouldMatchGreedily = targetDimIdx == numTargetDims - 1;
221-
FailureOr<ReassociationIndexRange> sourceRange;
222-
if (targetSize == ShapedType::kDynamic) {
223-
sourceRange = findReassociationRangeForDynamicDim(
224-
sourceShape, sourceDimIdx, shouldMatchGreedily);
225-
} else {
226-
sourceRange = findReassociationRangeForSize(
227-
sourceShape, sourceDimIdx, targetSize, shouldMatchGreedily);
49+
int64_t currTargetShape = targetShape[targetDim];
50+
while (sourceDim < (sourceShape.size() - 1) &&
51+
sourceShape[sourceDim] != ShapedType::kDynamic &&
52+
prodOfCollapsedDims * sourceShape[sourceDim] < currTargetShape) {
53+
prodOfCollapsedDims *= sourceShape[sourceDim];
54+
currIndices.push_back(sourceDim++);
22855
}
22956

230-
// Run sanity checks on the returned index range.
231-
if (failed(sourceRange) || failed(sourceRange->verify()) ||
232-
!sourceRange->isInRange(sourceShapeAsRange))
233-
return failure();
234-
if (sourceRange->leftIdx > sourceDimIdx) {
235-
// If some source dimensions had to be skipped in order to find a match,
236-
// they must be collapsed into the directly preceding dynamic dimension.
237-
if (!prevTargetSize || prevTargetSize != ShapedType::kDynamic)
238-
return failure();
239-
reassocRanges.back().rightIdx = sourceRange->leftIdx - 1;
240-
}
241-
242-
// Store the gathered information as required for the next iteration.
243-
prevTargetSize = targetSize;
244-
sourceDimIdx = sourceRange->rightIdx + 1;
245-
reassocRanges.push_back(*sourceRange);
57+
// If the current expanded dimension is dynamic, then the collapsed
58+
// dimensions should also be dynamic and the product of all previous unprocessed
59+
// dimensions of the expanded shape should be 1.
60+
if (sourceShape[sourceDim] == ShapedType::kDynamic &&
61+
(currTargetShape != ShapedType::kDynamic || prodOfCollapsedDims != 1))
62+
return std::nullopt;
63+
64+
// If the collapsed dim is dynamic, the current expanded dim should also
65+
// be dynamic.
66+
if (currTargetShape == ShapedType::kDynamic &&
67+
sourceShape[sourceDim] != ShapedType::kDynamic)
68+
return std::nullopt;
69+
70+
// For static shapes, the product of dimensions of the expanded shape
71+
// should match the collapsed dimension shape.
72+
if (prodOfCollapsedDims * sourceShape[sourceDim] != currTargetShape)
73+
return std::nullopt;
74+
75+
currIndices.push_back(sourceDim++);
76+
reassociationMap.emplace_back(ReassociationIndices{});
77+
std::swap(reassociationMap.back(), currIndices);
78+
prodOfCollapsedDims = 1;
24679
}
247-
// Fail if the source shape wasn't a full match for the target shape. We only
248-
// need to check the last recorded index - any other gaps should have been
249-
// mended by the main loop.
250-
if (reassocRanges.back().rightIdx < sourceShapeAsRange.rightIdx)
251-
return failure();
252-
return reassocRanges;
253-
}
254-
255-
/// A variant of `findReassociationRangesForCollapse(...)` that can also scan
256-
/// the shapes right-to-left.
257-
static FailureOr<SmallVector<ReassociationIndexRange>>
258-
findReassociationRangesForCollapse(ArrayRef<int64_t> sourceShape,
259-
ArrayRef<int64_t> targetShape,
260-
bool iterateRightToLeft) {
261-
if (!iterateRightToLeft)
262-
return findReassociationRangesForCollapse(sourceShape, targetShape);
263-
// NB: To iterate right-to-left, we currently reverse the shapes and then
264-
// reverse the result back. The reversed shapes must not be temporary, as
265-
// we're passing through an ArrayRef.
266-
// FIXME: It would be preferable to avoid the expensive copies. At the moment,
267-
// this approach is chosen for readability of the main implementation.
268-
std::vector<int64_t> sourceToReverse = sourceShape.vec(),
269-
targetToReverse = targetShape.vec();
270-
std::reverse(sourceToReverse.begin(), sourceToReverse.end());
271-
std::reverse(targetToReverse.begin(), targetToReverse.end());
272-
auto invertedRanges =
273-
findReassociationRangesForCollapse(sourceToReverse, targetToReverse);
274-
if (failed(invertedRanges))
275-
return failure();
276-
SmallVector<ReassociationIndexRange> &rangesToInvert = *invertedRanges;
277-
unsigned numSourceDims = sourceShape.size();
278-
// We have received the ranges for inverted shapes. Now we have to invert
279-
// the ranges back to correspond with the original source shape.
280-
for (auto &range : rangesToInvert) {
281-
int64_t invLeftIdx = range.leftIdx, invRightIdx = range.rightIdx;
282-
range.leftIdx = numSourceDims - 1 - invRightIdx;
283-
range.rightIdx = numSourceDims - 1 - invLeftIdx;
284-
}
285-
// Also invert the ordering of the ranges to correspond with the original
286-
// target shape.
287-
std::reverse(rangesToInvert.begin(), rangesToInvert.end());
288-
return rangesToInvert;
289-
}
290-
291-
std::optional<SmallVector<ReassociationIndices>>
292-
mlir::getReassociationIndicesForCollapse(ArrayRef<int64_t> sourceShape,
293-
ArrayRef<int64_t> targetShape) {
294-
unsigned numSourceDims = sourceShape.size(),
295-
numTargetDims = targetShape.size();
296-
// We're supposed to search for a collapsing reassociation. If the sizes
297-
// match, there's no actual collapsing taking place - it's either a no-op or a
298-
// `tensor.reshape`-style reassociation (that would be beyond the scope of
299-
// this utility).
300-
if (numSourceDims <= numTargetDims)
301-
return std::nullopt;
302-
// Early handling for scalar target types.
303-
if (numTargetDims == 0) {
304-
ReassociationIndices allSourceIndices;
305-
allSourceIndices.reserve(numSourceDims);
306-
for (unsigned sourceDimIdx = 0; sourceDimIdx < numSourceDims;
307-
++sourceDimIdx) {
308-
int64_t sourceSize = sourceShape[sourceDimIdx];
309-
// All source dimensions must be unit or dynamic.
310-
if (sourceSize != 1 && sourceSize != ShapedType::kDynamic)
311-
return std::nullopt;
312-
allSourceIndices.push_back(sourceDimIdx);
313-
}
314-
return SmallVector<ReassociationIndices>{allSourceIndices};
315-
}
316-
317-
// Collect source ranges by iterating over the target shape left-to-right.
318-
FailureOr<SmallVector<ReassociationIndexRange>> maybeForwardRanges =
319-
findReassociationRangesForCollapse(sourceShape, targetShape);
320-
if (failed(maybeForwardRanges))
321-
return std::nullopt;
322-
auto &ranges = *maybeForwardRanges;
323-
// Now do the same in reverse. We need to get another valid reassociation
324-
// through some other strategy, and then compare the results in order to
325-
// disambiguate mixed subshapes, such as:
326-
// ?x?x? into ?x?, ?x2x? into ?x?, ?x2x3x6x? into ?x6x?
327-
// This leads us to lose some of the reassociation opportunities that can only
328-
// be found by iterating in a certain direction, e.g. 2x2x? into 2x? - without
329-
// backtracking, the algorithm will fail right-to-left. However, this is the
330-
// best way to preserve correctness.
331-
FailureOr<SmallVector<ReassociationIndexRange>> maybeReverseRanges =
332-
findReassociationRangesForCollapse(sourceShape, targetShape,
333-
/*iterateRightToLeft=*/true);
334-
if (failed(maybeReverseRanges))
335-
return std::nullopt;
336-
auto &reverseRanges = *maybeReverseRanges;
337-
338-
if (ranges.size() != numTargetDims || reverseRanges.size() != numTargetDims)
80+
// All the dimensions in the target must have been processed.
81+
if (reassociationMap.size() != targetShape.size())
33982
return std::nullopt;
340-
// Now we can check for ambiguity of each target dimension's reassociation. If
341-
// successful, we put the full indices into our result map for the target
342-
// shape.
343-
SmallVector<ReassociationIndices> reassociationMap(numTargetDims);
344-
for (unsigned targetDimIdx = 0; targetDimIdx < numTargetDims;
345-
++targetDimIdx) {
346-
ReassociationIndexRange &range = ranges[targetDimIdx];
347-
ReassociationIndexRange &reverseRange = reverseRanges[targetDimIdx];
348-
// Get non-overlapping indices between the ranges
349-
ReassociationIndices nonMatchingIndices =
350-
range.getNonOverlappingIndicesWith(reverseRange);
351-
// Unit dimensions can be collapsed wherever - this is the only ambiguity
352-
// that we allow.
353-
for (int64_t sourceDimIdx : nonMatchingIndices) {
354-
if (sourceShape[sourceDimIdx] != 1)
355-
return std::nullopt;
356-
}
357-
reassociationMap[targetDimIdx] = range.getFullIndices();
83+
// Process any remaining entries in the source shape. They all need to be
84+
// 1 or dynamic.
85+
for (; sourceDim < sourceShape.size(); sourceDim++) {
86+
if (sourceShape[sourceDim] != ShapedType::kDynamic &&
87+
sourceShape[sourceDim] != 1)
88+
return std::nullopt;
89+
// The map is empty when the target type is a scalar.
90+
if (!reassociationMap.empty())
91+
reassociationMap.back().push_back(sourceDim);
35892
}
35993
return reassociationMap;
36094
}

0 commit comments

Comments
 (0)