Skip to content

Commit 40e1e29

Browse files
committed
[fixup] apply greedy logic suggestions
Signed-off-by: Artem Gindinson <[email protected]>
1 parent 184127f commit 40e1e29

File tree

1 file changed

+30
-25
lines changed

1 file changed

+30
-25
lines changed

mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp

Lines changed: 30 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -96,24 +96,24 @@ static FailureOr<ReassociationIndexRange>
9696
findReassociationRangeForDynamicDim(ArrayRef<int64_t> sourceShape,
9797
int64_t sourceStartIdx,
9898
bool matchGreedily = false) {
99-
ReassociationIndexRange iterationRange{sourceStartIdx, sourceStartIdx};
10099
const unsigned numSourceDims = sourceShape.size();
101100
ReassociationIndexRange sourceShapeAsRange{0, numSourceDims - 1};
102-
auto resultRange = iterationRange;
101+
std::optional<ReassociationIndexRange> resultRange = std::nullopt;
103102

104-
bool foundDynamic = false;
103+
ReassociationIndexRange iterationRange{sourceStartIdx, sourceStartIdx};
105104
for (; iterationRange.isInRange(sourceShapeAsRange);
106105
iterationRange.rightIdx++) {
107106
int64_t sourceSize = sourceShape[iterationRange.rightIdx];
108-
if (foundDynamic && !matchGreedily)
107+
if (sourceSize == ShapedType::kDynamic) {
108+
resultRange = iterationRange;
109109
break;
110-
if (sourceSize == ShapedType::kDynamic)
111-
foundDynamic = true;
112-
resultRange = iterationRange;
110+
}
113111
}
114-
if (!foundDynamic)
112+
if (!resultRange)
115113
return failure();
116-
return resultRange;
114+
if (matchGreedily)
115+
resultRange->rightIdx = sourceShapeAsRange.rightIdx;
116+
return *resultRange;
117117
}
118118

119119
/// Starting from `sourceStartIdx`, searches `sourceShape` for the first
@@ -125,31 +125,24 @@ static FailureOr<ReassociationIndexRange>
125125
findReassociationRangeForSize(ArrayRef<int64_t> sourceShape,
126126
int64_t sourceStartIdx, int64_t targetSize,
127127
bool matchGreedily = false) {
128-
ReassociationIndexRange iterationRange{sourceStartIdx, sourceStartIdx};
129128
const unsigned numSourceDims = sourceShape.size();
130129
ReassociationIndexRange sourceShapeAsRange{0, numSourceDims - 1};
131-
auto resultRange = iterationRange;
130+
std::optional<ReassociationIndexRange> resultRange = std::nullopt;
132131

132+
ReassociationIndexRange iterationRange{sourceStartIdx, sourceStartIdx};
133133
int64_t prodOfCollapsedDims = 1;
134-
bool reachedTargetDimSize = false;
135134
while (iterationRange.isInRange(sourceShapeAsRange)) {
136135
int64_t sourceSize = sourceShape[iterationRange.rightIdx];
137-
if (reachedTargetDimSize && !matchGreedily)
138-
break;
139136
if (sourceSize == ShapedType::kDynamic) {
140-
if (reachedTargetDimSize)
141-
break;
142137
// Reassociation for a static dim cannot include a dynamic dim. Reset
143138
// induction variables to essentially restart the loop from the next
144139
// source dimension.
145140
prodOfCollapsedDims = 1;
146-
resultRange = {iterationRange.rightIdx + 1, iterationRange.rightIdx + 1};
147-
iterationRange = resultRange;
141+
iterationRange = {iterationRange.rightIdx + 1,
142+
iterationRange.rightIdx + 1};
148143
continue;
149144
}
150145
prodOfCollapsedDims *= sourceSize;
151-
if (prodOfCollapsedDims > targetSize && reachedTargetDimSize)
152-
break;
153146
// If the target size has been exceeded without matching, we need to shift
154147
// the range start right. From the start of the range, roll back the
155148
// multiplication until the target size exceeds the product again.
@@ -160,17 +153,29 @@ findReassociationRangeForSize(ArrayRef<int64_t> sourceShape,
160153
// Shrink the range rightwards
161154
iterationRange.leftIdx++;
162155
}
163-
resultRange = iterationRange;
164156
// We could've reached the target size with the current dimension,
165157
// also as a result of the above shift to right.
166-
if (prodOfCollapsedDims == targetSize)
167-
reachedTargetDimSize = true;
158+
if (prodOfCollapsedDims == targetSize) {
159+
resultRange = iterationRange;
160+
break;
161+
}
168162
// Increment the iteration range
169163
iterationRange.rightIdx++;
170164
}
171-
if (!reachedTargetDimSize)
165+
if (!resultRange)
172166
return failure();
173-
return resultRange;
167+
if (matchGreedily) {
168+
// We now want to collect all unit dimensions directly after the target
169+
// product match. Advance the iterator to avoid OOB when the product match
170+
// happens at the last element.
171+
iterationRange.rightIdx++;
172+
while (iterationRange.isInRange(sourceShapeAsRange) &&
173+
sourceShape[iterationRange.rightIdx] == 1) {
174+
resultRange = iterationRange;
175+
iterationRange.rightIdx++;
176+
}
177+
}
178+
return *resultRange;
174179
}
175180

176181
/// Attempts to find a valid collapsing reassociation of `sourceShape` into

0 commit comments

Comments
 (0)