Skip to content

Commit 184127f

Browse files
committed
[fixup] apply non-functional comments
Signed-off-by: Artem Gindinson <[email protected]>
1 parent da46b10 commit 184127f

File tree

1 file changed

+16
-23
lines changed

1 file changed

+16
-23
lines changed

mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp

Lines changed: 16 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -59,12 +59,9 @@ struct ReassociationIndexRange {
5959
}
6060
bool containsSingleIndex() const { return size() == 1; }
6161

62-
void expandRight() { ++rightIdx; }
63-
void shrinkLeft() { ++leftIdx; }
64-
65-
/// Implements arithmetic XOR semantics to get non-overlapping indices between
66-
/// ranges.
67-
ReassociationIndices operator^(ReassociationIndexRange &rhs) const {
62+
/// Collects indices that do not overlap between this and another range.
63+
ReassociationIndices
64+
getNonOverlappingIndicesWith(ReassociationIndexRange &rhs) const {
6865
ReassociationIndices result;
6966
result.reserve(size() + rhs.size() / 2); // Attempt to amortize
7067
for (int64_t idx = this->leftIdx; idx <= this->rightIdx; ++idx) {
@@ -87,27 +84,26 @@ struct ReassociationIndexRange {
8784
return result;
8885
}
8986
};
87+
} // namespace
9088

9189
/// Starting from `sourceStartIdx`, searches `sourceShape` for the first
9290
/// sequence that can be collapsed into a dynamic dimension (at least one must
9391
/// be present in the source).
9492
/// By default, lazily returns once the first dynamic dimension has been found.
9593
/// Setting `matchGreedily` as `true` will also mark all subsequent
9694
/// source dimensions for collapsing into the target.
97-
FailureOr<ReassociationIndexRange>
95+
static FailureOr<ReassociationIndexRange>
9896
findReassociationRangeForDynamicDim(ArrayRef<int64_t> sourceShape,
9997
int64_t sourceStartIdx,
10098
bool matchGreedily = false) {
10199
ReassociationIndexRange iterationRange{sourceStartIdx, sourceStartIdx};
102100
const unsigned numSourceDims = sourceShape.size();
103101
ReassociationIndexRange sourceShapeAsRange{0, numSourceDims - 1};
104-
if (!iterationRange.isInRange(sourceShapeAsRange))
105-
return failure();
106102
auto resultRange = iterationRange;
107103

108104
bool foundDynamic = false;
109105
for (; iterationRange.isInRange(sourceShapeAsRange);
110-
iterationRange.expandRight()) {
106+
iterationRange.rightIdx++) {
111107
int64_t sourceSize = sourceShape[iterationRange.rightIdx];
112108
if (foundDynamic && !matchGreedily)
113109
break;
@@ -125,15 +121,13 @@ findReassociationRangeForDynamicDim(ArrayRef<int64_t> sourceShape,
125121
/// By default, lazily returns once the product matches the target size. Setting
126122
/// `matchGreedily` as `true` will append all neighboring unit dimensions
127123
/// (dimensions of 1) to the match.
128-
FailureOr<ReassociationIndexRange>
124+
static FailureOr<ReassociationIndexRange>
129125
findReassociationRangeForSize(ArrayRef<int64_t> sourceShape,
130126
int64_t sourceStartIdx, int64_t targetSize,
131127
bool matchGreedily = false) {
132128
ReassociationIndexRange iterationRange{sourceStartIdx, sourceStartIdx};
133129
const unsigned numSourceDims = sourceShape.size();
134130
ReassociationIndexRange sourceShapeAsRange{0, numSourceDims - 1};
135-
if (!iterationRange.isInRange(sourceShapeAsRange))
136-
return failure();
137131
auto resultRange = iterationRange;
138132

139133
int64_t prodOfCollapsedDims = 1;
@@ -163,15 +157,16 @@ findReassociationRangeForSize(ArrayRef<int64_t> sourceShape,
163157
!iterationRange.containsSingleIndex()) {
164158
int64_t frontSourceSize = sourceShape[iterationRange.leftIdx];
165159
prodOfCollapsedDims /= frontSourceSize;
166-
iterationRange.shrinkLeft();
160+
// Shrink the range rightwards
161+
iterationRange.leftIdx++;
167162
}
168163
resultRange = iterationRange;
169164
// We could've reached the target size with the current dimension,
170165
// also as a result of the above shift to right.
171166
if (prodOfCollapsedDims == targetSize)
172167
reachedTargetDimSize = true;
173168
// Increment the iteration range
174-
iterationRange.expandRight();
169+
iterationRange.rightIdx++;
175170
}
176171
if (!reachedTargetDimSize)
177172
return failure();
@@ -191,7 +186,7 @@ findReassociationRangeForSize(ArrayRef<int64_t> sourceShape,
191186
/// linear complexity). As feasible, consider adding further backtracking
192187
/// routines to enable more reassociations, e.g.:
193188
/// - ?x2x?x2 into ?x2
194-
FailureOr<SmallVector<ReassociationIndexRange>>
189+
static FailureOr<SmallVector<ReassociationIndexRange>>
195190
findReassociationRangesForCollapse(ArrayRef<int64_t> sourceShape,
196191
ArrayRef<int64_t> targetShape) {
197192
unsigned numSourceDims = sourceShape.size(),
@@ -236,7 +231,7 @@ findReassociationRangesForCollapse(ArrayRef<int64_t> sourceShape,
236231
// Store the gathered information as required for the next iteration.
237232
prevTargetSize = targetSize;
238233
sourceDimIdx = sourceRange->rightIdx + 1;
239-
reassocRanges.emplace_back(std::move(*sourceRange));
234+
reassocRanges.push_back(*sourceRange);
240235
}
241236
// Fail if the source shape wasn't a full match for the target shape. We only
242237
// need to check the last recorded index - any other gaps should have been
@@ -248,7 +243,7 @@ findReassociationRangesForCollapse(ArrayRef<int64_t> sourceShape,
248243

249244
/// A variant of `findReassociationRangesForCollapse(...)` that can also scan
250245
/// the shapes right-to-left.
251-
FailureOr<SmallVector<ReassociationIndexRange>>
246+
static FailureOr<SmallVector<ReassociationIndexRange>>
252247
findReassociationRangesForCollapse(ArrayRef<int64_t> sourceShape,
253248
ArrayRef<int64_t> targetShape,
254249
bool iterateRightToLeft) {
@@ -268,8 +263,6 @@ findReassociationRangesForCollapse(ArrayRef<int64_t> sourceShape,
268263
// We have received the ranges for inverted shapes. Now we have to invert
269264
// the ranges back to correspond with the original source shape.
270265
for (auto &range : rangesToInvert) {
271-
if (failed(range.verify()))
272-
return failure();
273266
int64_t invLeftIdx = range.leftIdx, invRightIdx = range.rightIdx;
274267
range.leftIdx = numSourceDims - 1 - invRightIdx;
275268
range.rightIdx = numSourceDims - 1 - invLeftIdx;
@@ -279,7 +272,6 @@ findReassociationRangesForCollapse(ArrayRef<int64_t> sourceShape,
279272
std::reverse(rangesToInvert.begin(), rangesToInvert.end());
280273
return rangesToInvert;
281274
}
282-
} // namespace
283275

284276
std::optional<SmallVector<ReassociationIndices>>
285277
mlir::getReassociationIndicesForCollapse(ArrayRef<int64_t> sourceShape,
@@ -298,7 +290,7 @@ mlir::getReassociationIndicesForCollapse(ArrayRef<int64_t> sourceShape,
298290
// All source dimensions must be unit or dynamic.
299291
if (sourceSize != 1 && sourceSize != ShapedType::kDynamic)
300292
return std::nullopt;
301-
allSourceIndices.emplace_back(sourceDimIdx);
293+
allSourceIndices.push_back(sourceDimIdx);
302294
}
303295
return SmallVector<ReassociationIndices>{allSourceIndices};
304296
}
@@ -337,7 +329,8 @@ mlir::getReassociationIndicesForCollapse(ArrayRef<int64_t> sourceShape,
337329
auto &range = ranges[targetDimIdx];
338330
auto &reverseRange = reverseRanges[targetDimIdx];
339331
// Get non-overlapping indices between the ranges
340-
ReassociationIndices nonMatchingIndices = range ^ reverseRange;
332+
ReassociationIndices nonMatchingIndices =
333+
range.getNonOverlappingIndicesWith(reverseRange);
341334
// Unit dimensions can be collapsed wherever - this is the only ambiguity
342335
// that we allow.
343336
for (int64_t sourceDimIdx : nonMatchingIndices) {

0 commit comments

Comments
 (0)