@@ -96,24 +96,24 @@ static FailureOr<ReassociationIndexRange>
96
96
findReassociationRangeForDynamicDim (ArrayRef<int64_t > sourceShape,
97
97
int64_t sourceStartIdx,
98
98
bool matchGreedily = false ) {
99
- ReassociationIndexRange iterationRange{sourceStartIdx, sourceStartIdx};
100
99
const unsigned numSourceDims = sourceShape.size ();
101
100
ReassociationIndexRange sourceShapeAsRange{0 , numSourceDims - 1 };
102
- auto resultRange = iterationRange ;
101
+ std::optional<ReassociationIndexRange> resultRange = std::nullopt ;
103
102
104
- bool foundDynamic = false ;
103
+ ReassociationIndexRange iterationRange{sourceStartIdx, sourceStartIdx} ;
105
104
for (; iterationRange.isInRange (sourceShapeAsRange);
106
105
iterationRange.rightIdx ++) {
107
106
int64_t sourceSize = sourceShape[iterationRange.rightIdx ];
108
- if (foundDynamic && !matchGreedily)
107
+ if (sourceSize == ShapedType::kDynamic ) {
108
+ resultRange = iterationRange;
109
109
break ;
110
- if (sourceSize == ShapedType::kDynamic )
111
- foundDynamic = true ;
112
- resultRange = iterationRange;
110
+ }
113
111
}
114
- if (!foundDynamic )
112
+ if (!resultRange )
115
113
return failure ();
116
- return resultRange;
114
+ if (matchGreedily)
115
+ resultRange->rightIdx = sourceShapeAsRange.rightIdx ;
116
+ return *resultRange;
117
117
}
118
118
119
119
// / Starting from `sourceStartIdx`, searches `sourceShape` for the first
@@ -125,31 +125,24 @@ static FailureOr<ReassociationIndexRange>
125
125
findReassociationRangeForSize (ArrayRef<int64_t > sourceShape,
126
126
int64_t sourceStartIdx, int64_t targetSize,
127
127
bool matchGreedily = false ) {
128
- ReassociationIndexRange iterationRange{sourceStartIdx, sourceStartIdx};
129
128
const unsigned numSourceDims = sourceShape.size ();
130
129
ReassociationIndexRange sourceShapeAsRange{0 , numSourceDims - 1 };
131
- auto resultRange = iterationRange ;
130
+ std::optional<ReassociationIndexRange> resultRange = std::nullopt ;
132
131
132
+ ReassociationIndexRange iterationRange{sourceStartIdx, sourceStartIdx};
133
133
int64_t prodOfCollapsedDims = 1 ;
134
- bool reachedTargetDimSize = false ;
135
134
while (iterationRange.isInRange (sourceShapeAsRange)) {
136
135
int64_t sourceSize = sourceShape[iterationRange.rightIdx ];
137
- if (reachedTargetDimSize && !matchGreedily)
138
- break ;
139
136
if (sourceSize == ShapedType::kDynamic ) {
140
- if (reachedTargetDimSize)
141
- break ;
142
137
// Reassociation for a static dim cannot include a dynamic dim. Reset
143
138
// induction variables to essentially restart the loop from the next
144
139
// source dimension.
145
140
prodOfCollapsedDims = 1 ;
146
- resultRange = {iterationRange.rightIdx + 1 , iterationRange. rightIdx + 1 };
147
- iterationRange = resultRange ;
141
+ iterationRange = {iterationRange.rightIdx + 1 ,
142
+ iterationRange. rightIdx + 1 } ;
148
143
continue ;
149
144
}
150
145
prodOfCollapsedDims *= sourceSize;
151
- if (prodOfCollapsedDims > targetSize && reachedTargetDimSize)
152
- break ;
153
146
// If the target size has been exceeded without matching, we need to shift
154
147
// the range start right. From the start of the range, roll back the
155
148
// multiplication until the target size exceeds the product again.
@@ -160,17 +153,29 @@ findReassociationRangeForSize(ArrayRef<int64_t> sourceShape,
160
153
// Shrink the range rightwards
161
154
iterationRange.leftIdx ++;
162
155
}
163
- resultRange = iterationRange;
164
156
// We could've reached the target size with the current dimension,
165
157
// also as a result of the above shift to right.
166
- if (prodOfCollapsedDims == targetSize)
167
- reachedTargetDimSize = true ;
158
+ if (prodOfCollapsedDims == targetSize) {
159
+ resultRange = iterationRange;
160
+ break ;
161
+ }
168
162
// Increment the iteration range
169
163
iterationRange.rightIdx ++;
170
164
}
171
- if (!reachedTargetDimSize )
165
+ if (!resultRange )
172
166
return failure ();
173
- return resultRange;
167
+ if (matchGreedily) {
168
+ // We now want to collect all unit dimensions directly after the target
169
+ // product match. Advance the iterator to avoid OOB when the product match
170
+ // happens at the last element.
171
+ iterationRange.rightIdx ++;
172
+ while (iterationRange.isInRange (sourceShapeAsRange) &&
173
+ sourceShape[iterationRange.rightIdx ] == 1 ) {
174
+ resultRange = iterationRange;
175
+ iterationRange.rightIdx ++;
176
+ }
177
+ }
178
+ return *resultRange;
174
179
}
175
180
176
181
// / Attempts to find a valid collapsing reassociation of `sourceShape` into
0 commit comments