|
11 | 11 | //===----------------------------------------------------------------------===//
|
12 | 12 |
|
13 | 13 | #include "mlir/Analysis/SliceAnalysis.h"
|
| 14 | +#include "mlir/IR/Block.h" |
14 | 15 | #include "mlir/IR/BuiltinOps.h"
|
15 | 16 | #include "mlir/IR/Operation.h"
|
| 17 | +#include "mlir/IR/RegionGraphTraits.h" |
16 | 18 | #include "mlir/Interfaces/SideEffectInterfaces.h"
|
17 | 19 | #include "mlir/Support/LLVM.h"
|
| 20 | +#include "llvm/ADT/PostOrderIterator.h" |
18 | 21 | #include "llvm/ADT/SetVector.h"
|
19 | 22 | #include "llvm/ADT/SmallPtrSet.h"
|
20 | 23 |
|
@@ -164,60 +167,95 @@ mlir::getSlice(Operation *op, const BackwardSliceOptions &backwardSliceOptions,
|
164 | 167 | return topologicalSort(slice);
|
165 | 168 | }
|
166 | 169 |
|
167 |
| -namespace { |
168 |
| -/// DFS post-order implementation that maintains a global count to work across |
169 |
| -/// multiple invocations, to help implement topological sort on multi-root DAGs. |
170 |
| -/// We traverse all operations but only record the ones that appear in |
171 |
| -/// `toSort` for the final result. |
172 |
| -struct DFSState { |
173 |
| - DFSState(const SetVector<Operation *> &set) : toSort(set), seen() {} |
174 |
| - const SetVector<Operation *> &toSort; |
175 |
| - SmallVector<Operation *, 16> topologicalCounts; |
176 |
| - DenseSet<Operation *> seen; |
177 |
| -}; |
178 |
| -} // namespace |
179 |
| - |
180 |
| -static void dfsPostorder(Operation *root, DFSState *state) { |
181 |
| - SmallVector<Operation *> queue(1, root); |
182 |
| - std::vector<Operation *> ops; |
183 |
| - while (!queue.empty()) { |
184 |
| - Operation *current = queue.pop_back_val(); |
185 |
| - ops.push_back(current); |
186 |
| - for (Operation *op : current->getUsers()) |
187 |
| - queue.push_back(op); |
188 |
| - for (Region &region : current->getRegions()) { |
189 |
| - for (Operation &op : region.getOps()) |
190 |
| - queue.push_back(&op); |
| 170 | +/// TODO: deduplicate |
| 171 | +static SetVector<Block *> getTopologicallySortedBlocks(Region ®ion) { |
| 172 | + // For each block that has not been visited yet (i.e. that has no |
| 173 | + // predecessors), add it to the list as well as its successors. |
| 174 | + SetVector<Block *> blocks; |
| 175 | + for (Block &b : region) { |
| 176 | + if (blocks.count(&b) == 0) { |
| 177 | + llvm::ReversePostOrderTraversal<Block *> traversal(&b); |
| 178 | + blocks.insert(traversal.begin(), traversal.end()); |
191 | 179 | }
|
192 | 180 | }
|
| 181 | + assert(blocks.size() == region.getBlocks().size() && |
| 182 | + "some blocks are not sorted"); |
193 | 183 |
|
194 |
| - for (Operation *op : llvm::reverse(ops)) { |
195 |
| - if (state->seen.insert(op).second && state->toSort.count(op) > 0) |
196 |
| - state->topologicalCounts.push_back(op); |
| 184 | + return blocks; |
| 185 | +} |
| 186 | + |
| 187 | +/// Computes the common ancestor region of all operations in `ops`. Remembers |
| 188 | +/// all the traversed regions in `traversedRegions`. |
| 189 | +static Region *findCommonParentRegion(const SetVector<Operation *> &ops, |
| 190 | + DenseSet<Region *> &traversedRegions) { |
| 191 | + // Map to count the number of times a region was encountered. |
| 192 | + llvm::DenseMap<Region *, size_t> regionCounts; |
| 193 | + size_t expectedCount = ops.size(); |
| 194 | + |
| 195 | + // Walk the region tree for each operation towards the root and add to the |
| 196 | + // region count. |
| 197 | + Region *res = nullptr; |
| 198 | + for (Operation *op : ops) { |
| 199 | + Region *current = op->getParentRegion(); |
| 200 | + while (current) { |
| 201 | + // Insert or get the count. |
| 202 | + auto it = regionCounts.try_emplace(current, 0).first; |
| 203 | + size_t count = ++it->getSecond(); |
| 204 | + if (count == expectedCount) { |
| 205 | + res = current; |
| 206 | + break; |
| 207 | + } |
| 208 | + current = current->getParentRegion(); |
| 209 | + } |
| 210 | + } |
| 211 | + auto firstRange = llvm::make_first_range(regionCounts); |
| 212 | + traversedRegions.insert(firstRange.begin(), firstRange.end()); |
| 213 | + return res; |
| 214 | +} |
| 215 | + |
| 216 | +/// Topologically traverses `region` and insers all encountered operations in |
| 217 | +/// `toSort` into the result. Recursively traverses regions when they are |
| 218 | +/// present in `relevantRegions`. |
| 219 | +static void topoSortRegion(Region ®ion, |
| 220 | + const DenseSet<Region *> &relevantRegions, |
| 221 | + const SetVector<Operation *> &toSort, |
| 222 | + SetVector<Operation *> &result) { |
| 223 | + SetVector<Block *> sortedBlocks = getTopologicallySortedBlocks(region); |
| 224 | + for (Block *block : sortedBlocks) { |
| 225 | + for (Operation &op : *block) { |
| 226 | + if (toSort.contains(&op)) |
| 227 | + result.insert(&op); |
| 228 | + for (Region &subRegion : op.getRegions()) { |
| 229 | + // Skip regions that do not contain operations from `toSort`. |
| 230 | + if (!relevantRegions.contains(®ion)) |
| 231 | + continue; |
| 232 | + topoSortRegion(subRegion, relevantRegions, toSort, result); |
| 233 | + } |
| 234 | + } |
197 | 235 | }
|
198 | 236 | }
|
199 | 237 |
|
200 | 238 | SetVector<Operation *>
|
201 | 239 | mlir::topologicalSort(const SetVector<Operation *> &toSort) {
|
202 |
| - if (toSort.empty()) { |
| 240 | + if (toSort.size() <= 1) |
203 | 241 | return toSort;
|
204 |
| - } |
205 | 242 |
|
206 |
| - // Run from each root with global count and `seen` set. |
207 |
| - DFSState state(toSort); |
208 |
| - for (auto *s : toSort) { |
209 |
| - assert(toSort.count(s) == 1 && "NYI: multi-sets not supported"); |
210 |
| - dfsPostorder(s, &state); |
211 |
| - } |
212 |
| - |
213 |
| - // Reorder and return. |
214 |
| - SetVector<Operation *> res; |
215 |
| - for (auto it = state.topologicalCounts.rbegin(), |
216 |
| - eit = state.topologicalCounts.rend(); |
217 |
| - it != eit; ++it) { |
218 |
| - res.insert(*it); |
219 |
| - } |
220 |
| - return res; |
| 243 | + assert(llvm::all_of(toSort, |
| 244 | + [&](Operation *op) { return toSort.count(op) == 1; }) && |
| 245 | + "expected only unique set entries"); |
| 246 | + |
| 247 | + // First, find the root region to start the recursive traversal through the |
| 248 | + // IR. |
| 249 | + DenseSet<Region *> relevantRegions; |
| 250 | + Region *rootRegion = findCommonParentRegion(toSort, relevantRegions); |
| 251 | + assert(rootRegion && "expected all ops to have a common ancestor"); |
| 252 | + |
| 253 | + // Sort all element in `toSort` by recursively traversing the IR. |
| 254 | + SetVector<Operation *> result; |
| 255 | + topoSortRegion(*rootRegion, relevantRegions, toSort, result); |
| 256 | + assert(result.size() == toSort.size() && |
| 257 | + "expected all operations to be present in the result"); |
| 258 | + return result; |
221 | 259 | }
|
222 | 260 |
|
223 | 261 | /// Returns true if `value` (transitively) depends on iteration-carried values
|
|
0 commit comments