@@ -59,12 +59,9 @@ struct ReassociationIndexRange {
59
59
}
60
60
bool containsSingleIndex () const { return size () == 1 ; }
61
61
62
- void expandRight () { ++rightIdx; }
63
- void shrinkLeft () { ++leftIdx; }
64
-
65
- // / Implements arithmetic XOR semantics to get non-overlapping indices between
66
- // / ranges.
67
- ReassociationIndices operator ^(ReassociationIndexRange &rhs) const {
62
+ // / Collects indices that do not overlap between this and another range.
63
+ ReassociationIndices
64
+ getNonOverlappingIndicesWith (ReassociationIndexRange &rhs) const {
68
65
ReassociationIndices result;
69
66
result.reserve (size () + rhs.size () / 2 ); // Attempt to amortize
70
67
for (int64_t idx = this ->leftIdx ; idx <= this ->rightIdx ; ++idx) {
@@ -87,27 +84,26 @@ struct ReassociationIndexRange {
87
84
return result;
88
85
}
89
86
};
87
+ } // namespace
90
88
91
89
// / Starting from `sourceStartIdx`, searches `sourceShape` for the first
92
90
// / sequence that can be collapsed into a dynamic dimension (at least one must
93
91
// / be present in the source).
94
92
// / By default, lazily returns once the first dynamic dimension has been found.
95
93
// / Setting `matchGreedily` as `true` will also mark all subsequent
96
94
// / source dimensions for collapsing into the target.
97
- FailureOr<ReassociationIndexRange>
95
+ static FailureOr<ReassociationIndexRange>
98
96
findReassociationRangeForDynamicDim (ArrayRef<int64_t > sourceShape,
99
97
int64_t sourceStartIdx,
100
98
bool matchGreedily = false ) {
101
99
ReassociationIndexRange iterationRange{sourceStartIdx, sourceStartIdx};
102
100
const unsigned numSourceDims = sourceShape.size ();
103
101
ReassociationIndexRange sourceShapeAsRange{0 , numSourceDims - 1 };
104
- if (!iterationRange.isInRange (sourceShapeAsRange))
105
- return failure ();
106
102
auto resultRange = iterationRange;
107
103
108
104
bool foundDynamic = false ;
109
105
for (; iterationRange.isInRange (sourceShapeAsRange);
110
- iterationRange.expandRight () ) {
106
+ iterationRange.rightIdx ++ ) {
111
107
int64_t sourceSize = sourceShape[iterationRange.rightIdx ];
112
108
if (foundDynamic && !matchGreedily)
113
109
break ;
@@ -125,15 +121,13 @@ findReassociationRangeForDynamicDim(ArrayRef<int64_t> sourceShape,
125
121
// / By default, lazily returns once the product matches the target size. Setting
126
122
// / `matchGreedily` as `true` will append all neighboring unit dimensions
127
123
// / (dimensions of 1) to the match.
128
- FailureOr<ReassociationIndexRange>
124
+ static FailureOr<ReassociationIndexRange>
129
125
findReassociationRangeForSize (ArrayRef<int64_t > sourceShape,
130
126
int64_t sourceStartIdx, int64_t targetSize,
131
127
bool matchGreedily = false ) {
132
128
ReassociationIndexRange iterationRange{sourceStartIdx, sourceStartIdx};
133
129
const unsigned numSourceDims = sourceShape.size ();
134
130
ReassociationIndexRange sourceShapeAsRange{0 , numSourceDims - 1 };
135
- if (!iterationRange.isInRange (sourceShapeAsRange))
136
- return failure ();
137
131
auto resultRange = iterationRange;
138
132
139
133
int64_t prodOfCollapsedDims = 1 ;
@@ -163,15 +157,16 @@ findReassociationRangeForSize(ArrayRef<int64_t> sourceShape,
163
157
!iterationRange.containsSingleIndex ()) {
164
158
int64_t frontSourceSize = sourceShape[iterationRange.leftIdx ];
165
159
prodOfCollapsedDims /= frontSourceSize;
166
- iterationRange.shrinkLeft ();
160
+ // Shrink the range rightwards
161
+ iterationRange.leftIdx ++;
167
162
}
168
163
resultRange = iterationRange;
169
164
// We could've reached the target size with the current dimension,
170
165
// also as a result of the above shift to right.
171
166
if (prodOfCollapsedDims == targetSize)
172
167
reachedTargetDimSize = true ;
173
168
// Increment the iteration range
174
- iterationRange.expandRight () ;
169
+ iterationRange.rightIdx ++ ;
175
170
}
176
171
if (!reachedTargetDimSize)
177
172
return failure ();
@@ -191,7 +186,7 @@ findReassociationRangeForSize(ArrayRef<int64_t> sourceShape,
191
186
// / linear complexity). As feasible, consider adding further backtracking
192
187
// / routines to enable more reassociations, e.g.:
193
188
// / - ?x2x?x2 into ?x2
194
- FailureOr<SmallVector<ReassociationIndexRange>>
189
+ static FailureOr<SmallVector<ReassociationIndexRange>>
195
190
findReassociationRangesForCollapse (ArrayRef<int64_t > sourceShape,
196
191
ArrayRef<int64_t > targetShape) {
197
192
unsigned numSourceDims = sourceShape.size (),
@@ -236,7 +231,7 @@ findReassociationRangesForCollapse(ArrayRef<int64_t> sourceShape,
236
231
// Store the gathered information as required for the next iteration.
237
232
prevTargetSize = targetSize;
238
233
sourceDimIdx = sourceRange->rightIdx + 1 ;
239
- reassocRanges.emplace_back ( std::move ( *sourceRange) );
234
+ reassocRanges.push_back ( *sourceRange);
240
235
}
241
236
// Fail if the source shape wasn't a full match for the target shape. We only
242
237
// need to check the last recorded index - any other gaps should have been
@@ -248,7 +243,7 @@ findReassociationRangesForCollapse(ArrayRef<int64_t> sourceShape,
248
243
249
244
// / A variant of `findReassociationRangesForCollapse(...)` that can also scan
250
245
// / the shapes right-to-left.
251
- FailureOr<SmallVector<ReassociationIndexRange>>
246
+ static FailureOr<SmallVector<ReassociationIndexRange>>
252
247
findReassociationRangesForCollapse (ArrayRef<int64_t > sourceShape,
253
248
ArrayRef<int64_t > targetShape,
254
249
bool iterateRightToLeft) {
@@ -268,8 +263,6 @@ findReassociationRangesForCollapse(ArrayRef<int64_t> sourceShape,
268
263
// We have received the ranges for inverted shapes. Now we have to invert
269
264
// the ranges back to correspond with the original source shape.
270
265
for (auto &range : rangesToInvert) {
271
- if (failed (range.verify ()))
272
- return failure ();
273
266
int64_t invLeftIdx = range.leftIdx , invRightIdx = range.rightIdx ;
274
267
range.leftIdx = numSourceDims - 1 - invRightIdx;
275
268
range.rightIdx = numSourceDims - 1 - invLeftIdx;
@@ -279,7 +272,6 @@ findReassociationRangesForCollapse(ArrayRef<int64_t> sourceShape,
279
272
std::reverse (rangesToInvert.begin (), rangesToInvert.end ());
280
273
return rangesToInvert;
281
274
}
282
- } // namespace
283
275
284
276
std::optional<SmallVector<ReassociationIndices>>
285
277
mlir::getReassociationIndicesForCollapse (ArrayRef<int64_t > sourceShape,
@@ -298,7 +290,7 @@ mlir::getReassociationIndicesForCollapse(ArrayRef<int64_t> sourceShape,
298
290
// All source dimensions must be unit or dynamic.
299
291
if (sourceSize != 1 && sourceSize != ShapedType::kDynamic )
300
292
return std::nullopt;
301
- allSourceIndices.emplace_back (sourceDimIdx);
293
+ allSourceIndices.push_back (sourceDimIdx);
302
294
}
303
295
return SmallVector<ReassociationIndices>{allSourceIndices};
304
296
}
@@ -337,7 +329,8 @@ mlir::getReassociationIndicesForCollapse(ArrayRef<int64_t> sourceShape,
337
329
auto &range = ranges[targetDimIdx];
338
330
auto &reverseRange = reverseRanges[targetDimIdx];
339
331
// Get non-overlapping indices between the ranges
340
- ReassociationIndices nonMatchingIndices = range ^ reverseRange;
332
+ ReassociationIndices nonMatchingIndices =
333
+ range.getNonOverlappingIndicesWith (reverseRange);
341
334
// Unit dimensions can be collapsed wherever - this is the only ambiguity
342
335
// that we allow.
343
336
for (int64_t sourceDimIdx : nonMatchingIndices) {
0 commit comments