Skip to content

Commit e0b5f0d

Browse files
committed
[fixup] improve getNonOverlappingIndicesWith(&rhs)
Signed-off-by: Artem Gindinson <[email protected]>
1 parent 7f31389 commit e0b5f0d

File tree

1 file changed

+16
-9
lines changed

1 file changed

+16
-9
lines changed

mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -62,16 +62,23 @@ struct ReassociationIndexRange {
6262
/// Collects indices that do not overlap between this and another range.
6363
ReassociationIndices
6464
getNonOverlappingIndicesWith(ReassociationIndexRange &rhs) const {
65-
ReassociationIndices result;
66-
result.reserve(size() + rhs.size() / 2); // Attempt to amortize
67-
for (int64_t idx = this->leftIdx; idx <= this->rightIdx; ++idx) {
68-
if (idx < rhs.leftIdx || idx > rhs.rightIdx)
69-
result.push_back(idx);
70-
}
71-
for (int64_t rhsIndex = rhs.leftIdx; rhsIndex <= rhs.rightIdx; ++rhsIndex) {
72-
if (rhsIndex < leftIdx || rhsIndex > rightIdx)
73-
result.push_back(rhsIndex);
65+
if (rightIdx < rhs.leftIdx) {
66+
// The intervals do not overlap - concatenate the indices from both.
67+
auto jointFullIndices = getFullIndices();
68+
jointFullIndices.append(rhs.getFullIndices());
69+
return jointFullIndices;
7470
}
71+
ReassociationIndices result;
72+
// Handle the chunk left of the overlapping range.
73+
int64_t leftStart = std::min(leftIdx, rhs.leftIdx);
74+
int64_t leftEnd = std::max(leftIdx, rhs.leftIdx);
75+
llvm::append_range(result, llvm::seq(leftStart, leftEnd));
76+
// Handle the chunk right of the overlapping range. Symmetrically, we should
77+
// skip the edge of the overlap AND include the rightmost index.
78+
int64_t rightStart = std::min(rightIdx, rhs.rightIdx) + 1;
79+
int64_t rightEnd = std::max(rightIdx, rhs.rightIdx);
80+
if (rightStart < rightEnd)
81+
llvm::append_range(result, llvm::seq_inclusive(rightStart, rightEnd));
7582
return result;
7683
}
7784

0 commit comments

Comments
 (0)