@@ -168,12 +168,12 @@ SetVector<Block *> mlir::getBlocksSortedByDominance(Region ®ion) {
168
168
return blocks;
169
169
}
170
170
171
- // / Computes the common ancestor region of all operations in `ops`. Remembers
172
- // / all the traversed regions in `traversedRegions`.
173
- static Region *findCommonParentRegion (const SetVector<Operation *> &ops,
174
- DenseSet<Region *> &traversedRegions) {
171
+ // / Computes the closest common ancestor region of all operations in `ops`.
172
+ // / Remembers all the traversed regions in `traversedRegions`.
173
+ static Region *findCommonAncestorRegion (const SetVector<Operation *> &ops,
174
+ DenseSet<Region *> &traversedRegions) {
175
175
// Map to count the number of times a region was encountered.
176
- llvm:: DenseMap<Region *, size_t > regionCounts;
176
+ DenseMap<Region *, size_t > regionCounts;
177
177
size_t expectedCount = ops.size ();
178
178
179
179
// Walk the region tree for each operation towards the root and add to the
@@ -182,10 +182,8 @@ static Region *findCommonParentRegion(const SetVector<Operation *> &ops,
182
182
for (Operation *op : ops) {
183
183
Region *current = op->getParentRegion ();
184
184
while (current) {
185
- // Insert or get the count.
186
- auto it = regionCounts.try_emplace (current, 0 ).first ;
187
- size_t count = ++it->getSecond ();
188
- if (count == expectedCount) {
185
+ // Insert or update the count and compare it.
186
+ if (++regionCounts[current] == expectedCount) {
189
187
res = current;
190
188
break ;
191
189
}
@@ -197,11 +195,11 @@ static Region *findCommonParentRegion(const SetVector<Operation *> &ops,
197
195
return res;
198
196
}
199
197
200
- // / Topologically traverses `region` and insers all encountered operations in
198
+ // / Topologically traverses `region` and inserts all encountered operations in
201
199
// / `toSort` into the result. Recursively traverses regions when they are
202
200
// / present in `relevantRegions`.
203
201
static void topoSortRegion (Region ®ion,
204
- const DenseSet<Region *> &relevantRegions ,
202
+ const DenseSet<Region *> &ancestorRegions ,
205
203
const SetVector<Operation *> &toSort,
206
204
SetVector<Operation *> &result) {
207
205
SetVector<Block *> sortedBlocks = getBlocksSortedByDominance (region);
@@ -211,9 +209,9 @@ static void topoSortRegion(Region ®ion,
211
209
result.insert (&op);
212
210
for (Region &subRegion : op.getRegions ()) {
213
211
// Skip regions that do not contain operations from `toSort`.
214
- if (!relevantRegions .contains (®ion))
212
+ if (!ancestorRegions .contains (®ion))
215
213
continue ;
216
- topoSortRegion (subRegion, relevantRegions , toSort, result);
214
+ topoSortRegion (subRegion, ancestorRegions , toSort, result);
217
215
}
218
216
}
219
217
}
@@ -224,19 +222,15 @@ mlir::topologicalSort(const SetVector<Operation *> &toSort) {
224
222
if (toSort.size () <= 1 )
225
223
return toSort;
226
224
227
- assert (llvm::all_of (toSort,
228
- [&](Operation *op) { return toSort.count (op) == 1 ; }) &&
229
- " expected only unique set entries" );
230
-
231
225
// First, find the root region to start the recursive traversal through the
232
226
// IR.
233
- DenseSet<Region *> relevantRegions ;
234
- Region *rootRegion = findCommonParentRegion (toSort, relevantRegions );
227
+ DenseSet<Region *> ancestorRegions ;
228
+ Region *rootRegion = findCommonAncestorRegion (toSort, ancestorRegions );
235
229
assert (rootRegion && " expected all ops to have a common ancestor" );
236
230
237
231
// Sort all element in `toSort` by recursively traversing the IR.
238
232
SetVector<Operation *> result;
239
- topoSortRegion (*rootRegion, relevantRegions , toSort, result);
233
+ topoSortRegion (*rootRegion, ancestorRegions , toSort, result);
240
234
assert (result.size () == toSort.size () &&
241
235
" expected all operations to be present in the result" );
242
236
return result;
0 commit comments